mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
92 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 764404a27e | |||
| 8b9451b585 | |||
| ab4903e752 | |||
| 538cea6dbc | |||
| 5cd3572713 | |||
| 3399513e5e | |||
| 32fc4015ee | |||
| cc72c813bf | |||
| 606f31a86e | |||
| 4933c9dcc7 | |||
| 7e25385024 | |||
| cc70bff74d | |||
| 9f50913b9c | |||
| 4eb7694d47 | |||
| edb5559b5b | |||
| 552ec76195 | |||
| e75340b473 | |||
| 2a4c223ec7 | |||
| 1ee4d84f07 | |||
| 6bd40ca219 | |||
| b879cf3d04 | |||
| bd9e5c1a64 | |||
| 9271a0c900 | |||
| af2f044f5a | |||
| 0caba222ef | |||
| 6d73f5bfe6 | |||
| ef8f40c21b | |||
| 0232879245 | |||
| 2726b4e865 | |||
| e126d35249 | |||
| d7ae8cd699 | |||
| 2f96d8bf76 | |||
| e129c71b4f | |||
| a02d70389d | |||
| 0d4922ce49 | |||
| eaeff78924 | |||
| e2f3982e2c | |||
| a73ac2bdbb | |||
| 95de732e55 | |||
| b2383236ca | |||
| 4b98cc25c8 | |||
| 90780c4de8 | |||
| 6f6e046c53 | |||
| 8cd64eaad1 | |||
| e620395416 | |||
| 0fbcbcdb2e | |||
| 674f5dfd75 | |||
| 7d430c8067 | |||
| 5f114c1d74 | |||
| ad01ef19f4 | |||
| 59e8f4572c | |||
| 97e91698fb | |||
| af0294198a | |||
| 421fdcce96 | |||
| bb63ad9715 | |||
| 3c90a79c57 | |||
| 8e29c530ed | |||
| b573b7a052 | |||
| 926184110b | |||
| bf8ede852d | |||
| f73db4394b | |||
| bff91f9927 | |||
| 6d726266fd | |||
| 2962330bb1 | |||
| 067993bb11 | |||
| e4dd00c8f5 | |||
| e714ff22e2 | |||
| 3bbd161cfd | |||
| 6d7be63f59 | |||
| b9d0dfb9a2 | |||
| dce483060f | |||
| c32b9182d9 | |||
| a4d4ef0e7f | |||
| 9a5c96b2b1 | |||
| 0a6ca58299 | |||
| 688195fc46 | |||
| 99eb0bbafc | |||
| 16de8b3f19 | |||
| 580008663b | |||
| 52c424c5eb | |||
| 836195e59c | |||
| be09a59e05 | |||
| 373a169bd2 | |||
| 00536c6c5b | |||
| cdd3a859ef | |||
| 5276fc0d6f | |||
| 6a2882f978 | |||
| 8874547353 | |||
| 2864caad80 | |||
| d998660aa1 | |||
| 7e5f3b35e9 | |||
| 01fea7c407 |
@@ -18,11 +18,6 @@ name: Documentation
|
||||
on:
|
||||
# Allows running this workflow manually from the Actions tab
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
version:
|
||||
description: 'Version tag (e.g. v0.1.2) - Leave empty for standard main build'
|
||||
required: false
|
||||
type: string
|
||||
|
||||
# Triggers the workflow on push events to main for the docs folder
|
||||
push:
|
||||
@@ -59,13 +54,7 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: lerobot
|
||||
additional_args: >-
|
||||
--not_python_module
|
||||
${{
|
||||
(github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) ||
|
||||
(inputs.version != '' && format('--version {0}', inputs.version)) ||
|
||||
''
|
||||
}}
|
||||
additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }}
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ You can contribute in many ways:
|
||||
- **Documentation:** Improve examples, guides, and docstrings.
|
||||
- **Feedback:** Submit tickets related to bugs or desired new features.
|
||||
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f).
|
||||
If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw).
|
||||
|
||||
## Development Setup
|
||||
|
||||
|
||||
@@ -128,7 +128,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
## Resources
|
||||
|
||||
- **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API.
|
||||
- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447):** Detailed documentation covering assembly, teleoperation, dataset recording, training, and deployment. Verified by Seeed Studio and 5 global hackathon players.
|
||||
- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community.
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
|
||||
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.portaudio import PortAudioMicrophone, PortAudioMicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
    microphones_configs: dict[str, MicrophoneConfig],
    audio_chunks_number: int,
    audio_chunks_duration: float,
    repetitions: int,
    multiprocessing: bool = False,
):
    """Compare chunked microphone reads with the recordings written to disk.

    For every repetition, all microphones record to WAV files under
    ``outputs/audio_benchmark`` while ``read()`` is called ``audio_chunks_number``
    times, each call preceded by a sleep of ``audio_chunks_duration`` seconds.
    The concatenated chunks are then plotted against the content of the recorded
    files, and length statistics are printed per microphone.

    Args:
        microphones_configs: Microphone configurations keyed by microphone name.
        audio_chunks_number: Number of ``read()`` calls per repetition.
        audio_chunks_duration: Sleep time, in seconds, before each ``read()`` call.
        repetitions: Number of independent recording sessions.
        multiprocessing: Forwarded to ``async_microphones_start_recording``.
    """
    recording_dir = Path("outputs/audio_benchmark")
    recording_dir.mkdir(parents=True, exist_ok=True)

    # Create microphones
    microphones = make_microphones_from_configs(microphones_configs)

    # Connect microphones
    for microphone in microphones.values():
        microphone.connect()

    all_audio_chunks = []
    for i in range(repetitions):
        print(f"Repetition {i + 1}/{repetitions}...")

        # One list of chunks per microphone for this repetition
        audio_chunks = {microphone_key: [] for microphone_key in microphones}

        # Start recording
        async_microphones_start_recording(
            microphones,
            output_files=[
                recording_dir / f"{microphone_key}_recording_{i}.wav" for microphone_key in microphones
            ],
            multiprocessing=multiprocessing,
        )

        # Record audio chunks
        for j in range(audio_chunks_number):
            precise_sleep(audio_chunks_duration)

            for microphone_key, microphone in microphones.items():
                audio_chunk = microphone.read()
                print(f"{microphone_key} - repetition {i} - chunk {j} - samples {audio_chunk.shape[0]}")
                audio_chunks[microphone_key].append(audio_chunk)

        # Stop recording
        async_microphones_stop_recording(microphones)

        for microphone_key in microphones:
            audio_chunks[microphone_key] = np.concatenate(audio_chunks[microphone_key], axis=0)

        all_audio_chunks.append(audio_chunks)

    # Disconnect microphones
    for microphone in microphones.values():
        microphone.disconnect()

    # Compute statistics
    cmap = plt.get_cmap("tab10")
    # squeeze=False keeps `ax` two-dimensional even with a single repetition or a
    # single microphone; with the default squeezing, `ax[i, j]` would fail then.
    _, ax = plt.subplots(nrows=repetitions, ncols=len(microphones), squeeze=False)
    chunk_length = np.zeros((repetitions, len(microphones)))
    record_length = np.zeros((repetitions, len(microphones)))
    for i in range(repetitions):
        for j, (microphone_key, microphone) in enumerate(microphones.items()):
            # Get recorded audio chunks
            recorded_audio_chunks = all_audio_chunks[i][microphone_key]
            # Same guard as for the file data below, so that `.T` iterates over channels
            if recorded_audio_chunks.ndim == 1:
                recorded_audio_chunks = np.expand_dims(recorded_audio_chunks, axis=1)

            # Load recorded file
            recorded_data, _ = read(recording_dir / f"{microphone_key}_recording_{i}.wav")
            if recorded_data.ndim == 1:
                recorded_data = np.expand_dims(recorded_data, axis=1)

            record_length[i, j] = recorded_data.shape[0]
            chunk_length[i, j] = recorded_audio_chunks.shape[0]

            # Signed length gap between the file and the chunks, in samples
            length_gap = int(recorded_data.shape[0] - recorded_audio_chunks.shape[0])

            for k, (chunk_data, record_data) in enumerate(
                zip(recorded_audio_chunks.T, recorded_data.T, strict=False)
            ):
                # Plot audio chunks and recorded data
                ax[i, j].plot(
                    np.arange(0, len(chunk_data)) / microphone.sample_rate,
                    chunk_data,
                    label=f"audio chunks - channel {k}",
                    color=cmap(2 * k),
                )
                ax[i, j].plot(
                    np.arange(0, len(record_data)) / microphone.sample_rate,
                    record_data,
                    label=f"recorded data - channel {k}",
                    linestyle="dashed",
                    color=cmap(2 * k + 1),
                )

                # Plot absolute difference (errors should be located at the end of the recordings).
                # The shorter signal is zero-padded so both have the same length.
                if length_gap > 0:
                    chunk_data = np.append(chunk_data, np.zeros(length_gap))
                else:
                    record_data = np.append(record_data, np.zeros(-length_gap))
                ax[i, j].plot(
                    np.arange(0, len(record_data)) / microphone.sample_rate,
                    np.abs(chunk_data - record_data),
                    label=f"differences - channel {k}",
                    color="red",
                    linestyle="dotted",
                )
            ax[i, j].set_title(f"{microphone_key} - repetition {i}")
            ax[i, j].legend()

    plt.show()

    # Print statistics
    differences = record_length - chunk_length
    for i, (microphone_key, microphone) in enumerate(microphones.items()):
        print(
            f"Average recorded duration for {microphone_key} : {np.mean(record_length[:, i]) / microphone.sample_rate:.3f} seconds"
        )
        print(
            f"Average chunk duration for {microphone_key} : {np.mean(chunk_length[:, i]) / microphone.sample_rate:.3f} seconds"
        )
        print(f"Average difference for {microphone_key} : {np.mean(differences[:, i]):.3f} samples")
        print(
            f"Average difference for {microphone_key} : {np.mean(differences[:, i]) / microphone.sample_rate:.3f} seconds"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Enumerate the available microphones once; they are the default selection.
    detected_indices = [microphone["index"] for microphone in PortAudioMicrophone.find_microphones()]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--microphones_indices",
        type=int,
        nargs="+",
        default=detected_indices,
    )
    # Sample rates and channels default to None here and are expanded after parsing,
    # so that there is exactly one entry per *selected* microphone.
    parser.add_argument(
        "--microphones_sample_rate",
        type=float,
        nargs="+",
        default=None,
    )
    parser.add_argument(
        "--microphones_channels",
        type=int,
        nargs="+",
        default=None,
    )
    parser.add_argument("--audio_chunks_number", type=int, default=2)
    parser.add_argument(
        "--audio_chunks_duration",
        type=float,
        default=1.0,
    )
    parser.add_argument(
        "--repetitions",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--multiprocessing",
        action="store_true",
    )

    args = vars(parser.parse_args())

    # Remove the per-microphone options: `main` only accepts the built configs.
    indices = args.pop("microphones_indices")
    sample_rates = args.pop("microphones_sample_rate")
    channels_per_microphone = args.pop("microphones_channels")

    # A None entry is passed through to the config unchanged
    # (presumably the config then falls back to the device defaults; verify).
    if sample_rates is None:
        sample_rates = [None] * len(indices)
    if channels_per_microphone is None:
        channels_per_microphone = [None] * len(indices)

    args["microphones_configs"] = {}
    for index, sample_rate, channels in zip(
        indices,
        sample_rates,
        channels_per_microphone,
        strict=False,
    ):
        microphone_config = PortAudioMicrophoneConfig(
            microphone_index=index,
            sample_rate=sample_rate,
            channels=channels,
        )
        args["microphones_configs"][f"microphone_{index}"] = microphone_config

    main(**args)
|
||||
@@ -0,0 +1,136 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
from lerobot.microphones.anyskin import AnyskinSensorConfig
|
||||
from lerobot.microphones.configs import MicrophoneConfig
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
make_microphones_from_configs,
|
||||
)
|
||||
from lerobot.utils.robot_utils import (
|
||||
precise_sleep,
|
||||
)
|
||||
|
||||
|
||||
def main(
    sensors_configs: dict[str, MicrophoneConfig],
    multiprocessing: bool = False,
):
    """Record every configured sensor for ten seconds and print what was captured.

    Each sensor records to ``outputs/tactile_benchmark/<key>_recording.wav``.
    After the sleep, one ``read()`` per sensor is printed, the recording is
    stopped, and the written files are loaded back and printed.

    Args:
        sensors_configs: Sensor configurations keyed by sensor name.
        multiprocessing: Forwarded to ``async_microphones_start_recording``.
    """
    output_dir = Path("outputs/tactile_benchmark")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Build the sensors and connect all of them
    sensors = make_microphones_from_configs(sensors_configs)
    for device in sensors.values():
        device.connect()

    # One list of data chunks per sensor
    data_chunks = {name: [] for name in sensors}

    # Start recording, one output file per sensor
    wav_files = [output_dir / f"{name}_recording.wav" for name in sensors]
    async_microphones_start_recording(
        sensors,
        output_files=wav_files,
        multiprocessing=multiprocessing,
    )

    # Let the sensors record, then read from each of them once
    precise_sleep(10.0)

    for name, device in sensors.items():
        chunk = device.read()
        print(f"{name} - samples {chunk.shape[0]}")
        data_chunks[name].append(chunk)

    # Stop recording
    async_microphones_stop_recording(sensors)

    for name in sensors:
        data_chunks[name] = np.concatenate(data_chunks[name], axis=0)

    # Disconnect the sensors
    for device in sensors.values():
        device.disconnect()

    # Load the recorded files back and print their content
    for name in sensors:
        data, sample_rate = sf.read(output_dir / f"{name}_recording.wav")
        print(f"{name} - samples {data.shape[0]}")
        print(f"{name} - sample rate {sample_rate}")
        print(f"{name} - data {data}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All per-sensor options are zipped together below, so each one must be given:
    # a missing option would otherwise be None and make `zip` raise a TypeError.
    parser.add_argument(
        "--sensors_ports",
        type=str,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--sensors_baud_rate",
        type=int,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--sensors_sample_rate",
        type=int,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--sensors_channels",
        type=int,
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--multiprocessing",
        action="store_true",
    )

    args = vars(parser.parse_args())

    # Remove the per-sensor options: `main` only accepts the built configs.
    ports = args.pop("sensors_ports")
    baud_rates = args.pop("sensors_baud_rate")
    sample_rates = args.pop("sensors_sample_rate")
    parsed_channels = args.pop("sensors_channels")

    args["sensors_configs"] = {}
    for port, baud_rate, sample_rate, _ in zip(
        ports,
        baud_rates,
        sample_rates,
        parsed_channels,
        strict=False,
    ):
        # NOTE(review): the parsed --sensors_channels values only affect how many
        # sensors are built; every sensor is configured with this fixed channel list.
        # Confirm whether the override is intentional.
        sensor_config = AnyskinSensorConfig(
            sensor_port=port,
            baud_rate=baud_rate,
            sample_rate=sample_rate,
            channels=[1, 2, 3, 4, 5],
        )
        args["sensors_configs"][f"sensor_{port}"] = sensor_config

    main(**args)
|
||||
@@ -195,7 +195,6 @@ client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
|
||||
@@ -43,12 +43,13 @@ def main():
|
||||
keyboard.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
init_rerun(session_name="lekiwi_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -69,7 +70,7 @@ def main():
|
||||
_ = robot.send_action(action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
log_rerun_data(observation=observation, action=action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -89,12 +89,13 @@ def main():
|
||||
teleop_device.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
init_rerun(session_name="phone_so100_teleop", robot=robot, reset_time=True)
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -111,7 +112,7 @@ def main():
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
log_rerun_data(observation=phone_obs, action=joint_action, log_time=time.perf_counter() - start)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -94,9 +94,10 @@ def main():
|
||||
leader.connect()
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
init_rerun(session_name="so100_so100_EE_teleop", robot=follower, reset_time=True)
|
||||
|
||||
print("Starting teleop loop...")
|
||||
start = time.perf_counter()
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
@@ -116,7 +117,9 @@ def main():
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
log_rerun_data(
|
||||
observation=leader_ee_act, action=follower_joints_act, log_time=time.perf_counter() - start
|
||||
)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ def main():
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
client_device="cpu",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
|
||||
+3
-1
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.4"
|
||||
version = "0.4.3"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
dynamic = ["readme"]
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -147,6 +147,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci
|
||||
# Features
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
||||
peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"]
|
||||
audio = ["sounddevice>=0.5.1,<0.6.0", "soundfile>=0.13.1,<0.14.0", "librosa>=0.11.0,<0.12.0", "torchaudio>=2.6.0,<2.10.0"]
|
||||
|
||||
# Development
|
||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"]
|
||||
@@ -175,6 +176,7 @@ all = [
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[audio]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
"lerobot[video_benchmark]",
|
||||
|
||||
@@ -29,6 +29,7 @@ Example:
|
||||
print(lerobot.available_policies_per_env)
|
||||
print(lerobot.available_robots)
|
||||
print(lerobot.available_cameras)
|
||||
print(lerobot.available_microphones)
|
||||
print(lerobot.available_motors)
|
||||
```
|
||||
|
||||
@@ -174,6 +175,13 @@ available_cameras = [
|
||||
"intelrealsense",
|
||||
]
|
||||
|
||||
# lists all available microphones from `lerobot/microphones`
|
||||
available_microphones = [
|
||||
"portaudio",
|
||||
"touchlab",
|
||||
"anyskin",
|
||||
]
|
||||
|
||||
# lists all available motors from `lerobot/motors`
|
||||
available_motors = [
|
||||
"dynamixel",
|
||||
|
||||
@@ -126,12 +126,6 @@ class RobotClientConfig:
|
||||
|
||||
# Device configuration
|
||||
policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"})
|
||||
client_device: str = field(
|
||||
default="cpu",
|
||||
metadata={
|
||||
"help": "Device to move actions to after receiving from server (e.g., for downstream planners)"
|
||||
},
|
||||
)
|
||||
|
||||
# Control behavior configuration
|
||||
chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"})
|
||||
@@ -167,9 +161,6 @@ class RobotClientConfig:
|
||||
if not self.policy_device:
|
||||
raise ValueError("policy_device cannot be empty")
|
||||
|
||||
if not self.client_device:
|
||||
raise ValueError("client_device cannot be empty")
|
||||
|
||||
if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1:
|
||||
raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}")
|
||||
|
||||
@@ -193,7 +184,6 @@ class RobotClientConfig:
|
||||
"policy_type": self.policy_type,
|
||||
"pretrained_name_or_path": self.pretrained_name_or_path,
|
||||
"policy_device": self.policy_device,
|
||||
"client_device": self.client_device,
|
||||
"chunk_size_threshold": self.chunk_size_threshold,
|
||||
"fps": self.fps,
|
||||
"actions_per_chunk": self.actions_per_chunk,
|
||||
|
||||
@@ -18,7 +18,6 @@ import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -40,8 +39,8 @@ from lerobot.utils.utils import init_logging
|
||||
|
||||
Action = torch.Tensor
|
||||
|
||||
# observation as received from the robot (can be numpy arrays, floats, etc.)
|
||||
RawObservation = dict[str, Any]
|
||||
# observation as received from the robot
|
||||
RawObservation = dict[str, torch.Tensor]
|
||||
|
||||
# observation as those recorded in LeRobot dataset (keys are different)
|
||||
LeRobotObservation = dict[str, torch.Tensor]
|
||||
|
||||
@@ -381,8 +381,6 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer):
|
||||
action_tensor = torch.stack(processed_actions, dim=1).squeeze(0)
|
||||
self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}")
|
||||
|
||||
action_tensor = action_tensor.detach().cpu()
|
||||
|
||||
"""5. Convert to TimedAction list"""
|
||||
action_chunk = self._time_action_chunk(
|
||||
observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
|
||||
|
||||
@@ -25,7 +25,6 @@ python src/lerobot/async_inference/robot_client.py \
|
||||
--policy_type=act \
|
||||
--pretrained_name_or_path=user/model \
|
||||
--policy_device=mps \
|
||||
--client_device=cpu \
|
||||
--actions_per_chunk=50 \
|
||||
--chunk_size_threshold=0.5 \
|
||||
--aggregate_fn_name=weighted_average \
|
||||
@@ -41,7 +40,6 @@ from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import grpc
|
||||
@@ -49,6 +47,10 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -286,21 +288,6 @@ class RobotClient:
|
||||
timed_actions = pickle.loads(actions_chunk.data) # nosec
|
||||
deserialize_time = time.perf_counter() - deserialize_start
|
||||
|
||||
# Log device type of received actions
|
||||
if len(timed_actions) > 0:
|
||||
received_device = timed_actions[0].get_action().device.type
|
||||
self.logger.debug(f"Received actions on device: {received_device}")
|
||||
|
||||
# Move actions to client_device (e.g., for downstream planners that need GPU)
|
||||
client_device = self.config.client_device
|
||||
if client_device != "cpu":
|
||||
for timed_action in timed_actions:
|
||||
if timed_action.get_action().device.type != client_device:
|
||||
timed_action.action = timed_action.get_action().to(client_device)
|
||||
self.logger.debug(f"Converted actions to device: {client_device}")
|
||||
else:
|
||||
self.logger.debug(f"Actions kept on device: {client_device}")
|
||||
|
||||
self.action_chunk_size = max(self.action_chunk_size, len(timed_actions))
|
||||
|
||||
# Calculate network latency if we have matching observations
|
||||
@@ -367,7 +354,7 @@ class RobotClient:
|
||||
action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)}
|
||||
return action
|
||||
|
||||
def control_loop_action(self, verbose: bool = False) -> dict[str, Any]:
|
||||
def control_loop_action(self, verbose: bool = False) -> RobotAction:
|
||||
"""Reading and performing actions in local queue"""
|
||||
|
||||
# Lock only for queue operations
|
||||
|
||||
@@ -105,16 +105,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
def observation_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def image_observation_delta_indices(self) -> list | None: # type: ignore[type-arg]
|
||||
"""Return indices for delta image observations only.
|
||||
|
||||
Unlike observation_delta_indices which applies to ALL observations,
|
||||
this only applies to image observations (keys starting with observation.images).
|
||||
Default returns None. Override in subclass to enable.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None: # type: ignore[type-arg] #TODO: No implementation
|
||||
@@ -161,6 +151,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||
if not self.input_features:
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if not self.output_features:
|
||||
|
||||
@@ -20,6 +20,7 @@ from enum import Enum
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
AUDIO = "AUDIO"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
@@ -26,6 +26,8 @@ import tqdm
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -41,7 +43,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
|
||||
|
||||
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
||||
@@ -110,6 +112,7 @@ def update_meta_data(
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
):
|
||||
"""Updates metadata DataFrame with new chunk, file, and timestamp indices.
|
||||
|
||||
@@ -122,7 +125,7 @@ def update_meta_data(
|
||||
meta_idx: Dictionary containing current metadata chunk and file indices.
|
||||
data_idx: Dictionary containing current data chunk and file indices.
|
||||
videos_idx: Dictionary containing current video indices and timestamps.
|
||||
|
||||
audios_idx: Dictionary containing current audio indices and timestamps.
|
||||
Returns:
|
||||
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
|
||||
"""
|
||||
@@ -180,6 +183,36 @@ def update_meta_data(
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
for key, audio_idx in audios_idx.items():
|
||||
# Store original audio file indices before updating
|
||||
orig_chunk_col = f"audio/{key}/chunk_index"
|
||||
orig_file_col = f"audio/{key}/file_index"
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = audio_idx["chunk"]
|
||||
df[orig_file_col] = audio_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
src_to_offset = audio_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"audio/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"audio/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[f"audio/{key}/from_timestamp"] = (
|
||||
df[f"audio/{key}/from_timestamp"] + audio_idx["latest_duration"]
|
||||
)
|
||||
df[f"audio/{key}/to_timestamp"] = df[f"audio/{key}/to_timestamp"] + audio_idx["latest_duration"]
|
||||
|
||||
# Clean up temporary columns
|
||||
df = df.drop(columns=["_orig_chunk", "_orig_file"])
|
||||
|
||||
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
|
||||
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
|
||||
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
|
||||
@@ -194,6 +227,7 @@ def aggregate_datasets(
|
||||
aggr_root: Path | None = None,
|
||||
data_files_size_in_mb: float | None = None,
|
||||
video_files_size_in_mb: float | None = None,
|
||||
audio_files_size_in_mb: float | None = None,
|
||||
chunk_size: int | None = None,
|
||||
):
|
||||
"""Aggregates multiple LeRobot datasets into a single unified dataset.
|
||||
@@ -211,6 +245,7 @@ def aggregate_datasets(
|
||||
aggr_root: Optional root path for the aggregated dataset.
|
||||
data_files_size_in_mb: Maximum size for data files in MB (defaults to DEFAULT_DATA_FILE_SIZE_IN_MB)
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB (defaults to DEFAULT_AUDIO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
"""
|
||||
logging.info("Start aggregate_datasets")
|
||||
@@ -219,6 +254,8 @@ def aggregate_datasets(
|
||||
data_files_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_files_size_in_mb is None:
|
||||
video_files_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_files_size_in_mb is None:
|
||||
audio_files_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
if chunk_size is None:
|
||||
chunk_size = DEFAULT_CHUNK_SIZE
|
||||
|
||||
@@ -231,6 +268,7 @@ def aggregate_datasets(
|
||||
)
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
audio_keys = [key for key in features if features[key]["dtype"] == "audio"]
|
||||
|
||||
dst_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=aggr_repo_id,
|
||||
@@ -242,6 +280,7 @@ def aggregate_datasets(
|
||||
chunks_size=chunk_size,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
video_files_size_in_mb=video_files_size_in_mb,
|
||||
audio_files_size_in_mb=audio_files_size_in_mb,
|
||||
)
|
||||
|
||||
logging.info("Find all tasks")
|
||||
@@ -253,14 +292,18 @@ def aggregate_datasets(
|
||||
videos_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||
}
|
||||
audios_idx = {
|
||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in audio_keys
|
||||
}
|
||||
|
||||
dst_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chunk_size)
|
||||
audios_idx = aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size)
|
||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx)
|
||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx)
|
||||
|
||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||
@@ -328,7 +371,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
file_index=file_idx,
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
src_duration = get_media_duration_in_s(src_path, media_type="video")
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
@@ -367,7 +410,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
@@ -382,6 +425,101 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
return videos_idx
|
||||
|
||||
|
||||
def aggregate_audio(src_meta, dst_meta, audios_idx, audio_files_size_in_mb, chunk_size):
    """Aggregates audio files from a source dataset into the destination dataset.

    For every audio key, each distinct source (chunk, file) pair is either copied to
    the current destination file (when that file does not exist yet), concatenated to
    it, or copied to a new destination file when the combined size would reach
    `audio_files_size_in_mb`.

    Args:
        src_meta: Source dataset metadata.
        dst_meta: Destination dataset metadata.
        audios_idx: Dictionary mapping each audio key to its tracking state. The keys
            "chunk", "file" and "latest_duration" are read; "chunk", "file",
            "episode_duration" and "src_to_offset" are (re)written.
        audio_files_size_in_mb: Maximum size for audio files in MB.
        chunk_size: Maximum number of files per chunk.

    Returns:
        dict: The same `audios_idx` dictionary, updated in place with the current chunk
        and file indices, the total duration added from this source dataset, and the
        timestamp offset assigned to each source (chunk, file) pair.
    """
    for key in audios_idx:
        # Duration contributed by this source dataset only, so start from zero
        audios_idx[key]["episode_duration"] = 0
        # Track offset for each source (chunk, file) pair
        audios_idx[key]["src_to_offset"] = {}

    for key, audio_idx in audios_idx.items():
        # Several episodes can share one audio file: process each source file only once
        unique_chunk_file_pairs = {
            (chunk, file)
            for chunk, file in zip(
                src_meta.episodes[f"audio/{key}/chunk_index"],
                src_meta.episodes[f"audio/{key}/file_index"],
                strict=False,
            )
        }
        # Sorted so that source files are appended in a deterministic order
        unique_chunk_file_pairs = sorted(unique_chunk_file_pairs)

        chunk_idx = audio_idx["chunk"]
        file_idx = audio_idx["file"]
        # Offset (in seconds) at which the next source file starts in the destination file
        current_offset = audio_idx["latest_duration"]

        for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
            src_path = src_meta.root / DEFAULT_AUDIO_PATH.format(
                audio_key=key,
                chunk_index=src_chunk_idx,
                file_index=src_file_idx,
            )

            dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
                audio_key=key,
                chunk_index=chunk_idx,
                file_index=file_idx,
            )

            src_duration = get_media_duration_in_s(src_path, media_type="audio")

            if not dst_path.exists():
                # Store offset before incrementing
                audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                audios_idx[key]["episode_duration"] += src_duration
                current_offset += src_duration
                continue

            # Check file sizes before appending
            src_size = get_file_size_in_mb(src_path)
            dst_size = get_file_size_in_mb(dst_path)

            if dst_size + src_size >= audio_files_size_in_mb:
                # Rotate to a new file, this source becomes start of new destination
                # So its offset should be 0
                audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
                # NOTE(review): presumably advances to the next file index and rolls over
                # to the next chunk once `chunk_size` is reached - confirm against the helper
                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
                dst_path = dst_meta.root / DEFAULT_AUDIO_PATH.format(
                    audio_key=key,
                    chunk_index=chunk_idx,
                    file_index=file_idx,
                )
                dst_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy(str(src_path), str(dst_path))
                # Reset offset for next file
                current_offset = src_duration
            else:
                # Append to existing audio file - use current accumulated offset
                audios_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
                concatenate_media_files(
                    [dst_path, src_path],
                    dst_path,
                )
                current_offset += src_duration

            audios_idx[key]["episode_duration"] += src_duration

        # Persist the destination position so the next source dataset continues from here
        audios_idx[key]["chunk"] = chunk_idx
        audios_idx[key]["file"] = file_idx

    return audios_idx
|
||||
|
||||
|
||||
def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_size):
|
||||
"""Aggregates data chunks from a source dataset into the destination dataset.
|
||||
|
||||
@@ -436,7 +574,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
return data_idx
|
||||
|
||||
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, audios_idx):
|
||||
"""Aggregates metadata from a source dataset into the destination dataset.
|
||||
|
||||
Reads source metadata files, updates all indices and timestamps,
|
||||
@@ -448,6 +586,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx: Dictionary tracking metadata chunk and file indices.
|
||||
data_idx: Dictionary tracking data chunk and file indices.
|
||||
videos_idx: Dictionary tracking video indices and timestamps.
|
||||
audios_idx: Dictionary tracking audio indices and timestamps.
|
||||
|
||||
Returns:
|
||||
dict: Updated meta_idx with current chunk and file indices.
|
||||
@@ -471,6 +610,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
meta_idx,
|
||||
data_idx,
|
||||
videos_idx,
|
||||
audios_idx,
|
||||
)
|
||||
|
||||
meta_idx = append_or_create_parquet_file(
|
||||
@@ -487,7 +627,8 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
||||
# Increment latest_duration by the total duration added from this source dataset
|
||||
for k in videos_idx:
|
||||
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
|
||||
|
||||
for k in audios_idx:
|
||||
audios_idx[k]["latest_duration"] += audios_idx[k]["episode_duration"]
|
||||
return meta_idx
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import av
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchcodec
|
||||
from numpy import ceil
|
||||
|
||||
# Maps a number of audio channels to the name of the channel layout used for the
# output stream in `encode_audio`. Channel counts that are not listed have no entry.
CHANNELS_LAYOUTS_MAPPING = {
    1: "mono",
    2: "stereo",
    3: "2.1",
    4: "3.1",
    5: "4.1",
    6: "5.1",
    7: "6.1",
    8: "7.1",
    16: "hexadecagonal",
    24: "22.2",
}
|
||||
|
||||
|
||||
def decode_audio(
|
||||
audio_path: Path | str,
|
||||
timestamps: list[float],
|
||||
duration: float,
|
||||
start_time_s: float | None = 0.0,
|
||||
backend: str | None = "torchcodec",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes audio using the specified backend.
|
||||
Args:
|
||||
audio_path (Path): Path to the audio file.
|
||||
timestamps (list[float]): List of (starting) timestamps to extract audio chunks.
|
||||
duration (float): Duration of the audio chunks in seconds.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded audio chunks.
|
||||
|
||||
Currently supports torchaudio.
|
||||
"""
|
||||
if backend == "torchcodec":
|
||||
return decode_audio_torchcodec(audio_path, timestamps, duration, start_time_s)
|
||||
elif backend == "torchaudio":
|
||||
return decode_audio_torchaudio(audio_path, timestamps, duration, start_time_s)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_audio_torchcodec(
    audio_path: Path | str,
    timestamps: list[float],
    duration: float,
    start_time_s: float | None = 0.0,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """
    Decodes audio chunks with torchcodec.

    For each (offset) timestamp `ts`, the samples played in `[max(0, ts - duration), ts)` are
    decoded. When the requested chunk starts before the beginning of the stream, it is
    left-padded with zeros so that it has `ceil(duration * sample_rate)` samples.

    Args:
        audio_path (Path | str): Path to the audio file.
        timestamps (list[float]): Timestamps (in seconds) at which the audio chunks end.
        duration (float): Duration of the audio chunks in seconds.
        start_time_s (float, optional): Offset (in seconds) added to every timestamp.
            `None` is treated as no offset. Defaults to 0.0.
        log_loaded_timestamps (bool, optional): Whether to log the timestamp and duration of
            each decoded chunk. Defaults to False.

    Returns:
        torch.Tensor: The decoded chunks stacked along a new first dimension.
    """
    # TODO(CarolinePascal) : add channels selection
    audio_decoder = torchcodec.decoders.AudioDecoder(audio_path)
    audio_sample_rate = audio_decoder.metadata.sample_rate
    audio_channels = audio_decoder.metadata.num_channels
    # TODO(CarolinePascal) : assert ts < total record duration

    # `start_time_s` is annotated as optional: treat None as "no offset" instead of failing
    offset_s = 0.0 if start_time_s is None else start_time_s
    timestamps = [timestamp + offset_s for timestamp in timestamps]  # Add the offset to each timestamp

    # Number of samples every padded chunk must have (same length as the all-zeros chunk below)
    expected_num_samples = int(ceil(duration * audio_sample_rate))

    audio_chunks = []
    for ts in timestamps:
        current_audio_chunk = audio_decoder.get_samples_played_in_range(
            start_seconds=max(0.0, ts - duration), stop_seconds=ts
        )

        current_audio_chunk_data = current_audio_chunk.data

        # Case where the requested audio chunk starts before the beginning of the audio stream
        if ts - duration < 0:
            # No useful audio sample has been recorded
            if ts < 1 / audio_sample_rate:
                # TODO(CarolinePascal) : add low level white noise instead of zeros ?
                current_audio_chunk_data = torch.zeros((audio_channels, expected_num_samples))
            # At least one useful audio sample has been recorded
            else:
                # Pad the beginning of the audio chunk with zeros.
                # The pad is derived from the decoded length: rounding the missing duration
                # and the decoded duration up separately can add one sample too many, which
                # would make torch.stack fail on chunks of different lengths.
                # TODO(CarolinePascal) : add low level white noise instead of zeros ?
                missing_samples = max(0, expected_num_samples - current_audio_chunk_data.shape[-1])
                current_audio_chunk_data = torch.nn.functional.pad(
                    current_audio_chunk_data,
                    (missing_samples, 0, 0, 0),  # left, right, top, bottom
                )

        if log_loaded_timestamps:
            logging.info(
                f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
            )

        audio_chunks.append(current_audio_chunk_data)

    audio_chunks = torch.stack(audio_chunks)

    assert len(timestamps) == len(audio_chunks)
    return audio_chunks
|
||||
|
||||
|
||||
def decode_audio_torchaudio(
    audio_path: Path | str,
    timestamps: list[float],
    duration: float,
    start_time_s: float | None = 0.0,
    log_loaded_timestamps: bool = False,
) -> torch.Tensor:
    """
    Decodes audio chunks with torchaudio's `StreamReader`.

    For each (offset) timestamp `ts`, the reader seeks to `max(0, ts - duration)` and one chunk
    of `ceil(duration * sample_rate)` frames is read. When the requested chunk starts before the
    beginning of the stream, the samples located after `ts` are dropped and the chunk is
    left-padded with zeros so that it keeps `ceil(duration * sample_rate)` samples.

    Args:
        audio_path (Path | str): Path to the audio file.
        timestamps (list[float]): Timestamps (in seconds) at which the audio chunks end.
        duration (float): Duration of the audio chunks in seconds.
        start_time_s (float, optional): Offset (in seconds) added to every timestamp.
            `None` is treated as no offset. Defaults to 0.0.
        log_loaded_timestamps (bool, optional): Whether to log the starting timestamp and
            duration of each decoded chunk. Defaults to False.

    Returns:
        torch.Tensor: The decoded chunks (channel first) stacked along a new first dimension.
    """
    # TODO(CarolinePascal) : add channels selection
    audio_path = str(audio_path)

    reader = torchaudio.io.StreamReader(src=audio_path)
    audio_sample_rate = reader.get_src_stream_info(reader.default_audio_stream).sample_rate
    audio_channels = reader.get_src_stream_info(reader.default_audio_stream).num_channels
    # TODO(CarolinePascal) : assert ts < total record duration

    # TODO(CarolinePascal) : sort timestamps ?

    # Number of samples every padded chunk must have (same length as the all-zeros chunk below)
    expected_num_samples = int(ceil(duration * audio_sample_rate))

    reader.add_basic_audio_stream(
        frames_per_chunk=expected_num_samples,  # Too much is better than not enough
        buffer_chunk_size=-1,  # No dropping frames
        format="fltp",  # Format as float32
    )

    # `start_time_s` is annotated as optional: treat None as "no offset" instead of failing
    offset_s = 0.0 if start_time_s is None else start_time_s
    timestamps = [timestamp + offset_s for timestamp in timestamps]  # Add the offset to each timestamp

    audio_chunks = []
    for ts in timestamps:
        reader.seek(max(0.0, ts - duration))  # Default to closest audio sample. Needs to be non-negative !
        status = reader.fill_buffer()
        if status != 0:
            # Should not happen, but just in case
            logging.warning("Audio stream reached end of recording before decoding desired timestamps.")

        current_audio_chunk = reader.pop_chunks()[0]
        current_audio_chunk_data = current_audio_chunk.t()  # Channel first format

        # Case where the requested audio chunk starts before the beginning of the audio stream
        if ts - duration < 0:
            # No useful audio sample has been recorded
            if ts < 1 / audio_sample_rate:
                current_audio_chunk_data = torch.zeros((audio_channels, expected_num_samples))
            # At least one useful audio sample has been recorded
            else:
                # Remove the superfluous last samples of the audio chunk
                current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
                # Pad the beginning of the audio chunk with zeros.
                # The pad is derived from the kept length: rounding the missing duration and
                # the kept duration up separately can add one sample too many, which would
                # make torch.stack fail on chunks of different lengths.
                # TODO(CarolinePascal) : add low level white noise instead of zeros ?
                missing_samples = max(0, expected_num_samples - current_audio_chunk_data.shape[-1])
                current_audio_chunk_data = torch.nn.functional.pad(
                    current_audio_chunk_data,
                    (missing_samples, 0, 0, 0),  # left, right, top, bottom
                )

        if log_loaded_timestamps:
            # The presentation timestamp is exposed as the `pts` attribute of the chunk;
            # indexing the tensor with a string would raise
            logging.info(
                f"audio chunk loaded at starting timestamp={current_audio_chunk.pts:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
            )

        audio_chunks.append(current_audio_chunk_data)

    audio_chunks = torch.stack(audio_chunks)

    assert len(timestamps) == len(audio_chunks)
    return audio_chunks
|
||||
|
||||
|
||||
def encode_audio(
    input_path: Path | str,
    output_path: Path | str,
    codec: str = "aac",  # TODO(CarolinePascal) : investigate Fraunhofer FDK AAC (libfdk_aac) codec and and constant (file size control) /variable (quality control) bitrate options
    bit_rate: int | None = None,
    sample_rate: int | None = None,
    log_level: int | None = av.logging.ERROR,
    overwrite: bool = False,
) -> None:
    """Encodes the first audio stream of an audio file into a new file using PyAV.

    Args:
        input_path: Path to the audio file to encode.
        output_path: Path of the encoded file. Missing parent directories are created.
        codec: Name of the codec used for the output stream.
        bit_rate: Bit rate of the output stream. If None, the codec default is kept.
        sample_rate: Sample rate of the output stream. If None, the input stream rate is used.
        log_level: Level set on the "libav" logger during encoding. If None, logging is left untouched.
        overwrite: Whether an already existing output file may be replaced.

    Raises:
        FileExistsError: If `output_path` already exists and `overwrite` is False.
        KeyError: If the input channel count has no entry in `CHANNELS_LAYOUTS_MAPPING`.
        OSError: If the output file does not exist once encoding is done.
    """
    output_path = Path(output_path)
    # `overwrite` protects the output file itself: an already existing parent directory
    # must not make the encoding fail.
    if not overwrite and output_path.exists():
        raise FileExistsError(f"Output file already exists and overwrite is False: {output_path}.")
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Set logging level
    if log_level is not None:
        # "While less efficient, it is generally preferable to modify logging with Python’s logging"
        logging.getLogger("libav").setLevel(log_level)

    try:
        # Open input file
        with av.open(str(input_path), "r") as input_container:
            input_stream = input_container.streams.audio[0]  # Assuming the first stream is the audio stream to be encoded

            # Define sub-sampling options
            if sample_rate is None:
                sample_rate = input_stream.rate

            # Create and open output file
            with av.open(str(output_path), "w") as output_container:
                output_stream = output_container.add_stream(
                    codec, rate=sample_rate, layout=CHANNELS_LAYOUTS_MAPPING[input_stream.channels]
                )

                if bit_rate is not None:
                    output_stream.bit_rate = bit_rate

                # Loop through input frames and encode them
                # (decode handles both demuxing and decoding under the hood)
                for input_frame in input_container.decode(input_stream):
                    packet = output_stream.encode(input_frame)
                    if packet:
                        output_container.mux(packet)

                # Flush the encoder
                packet = output_stream.encode()
                if packet:
                    output_container.mux(packet)
    finally:
        # Reset logging level, even when encoding failed
        if log_level is not None:
            av.logging.restore_default_callback()

    if not output_path.exists():
        raise OSError(f"Audio encoding did not work. File not found: {output_path}.")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
    """Returns information about the first audio stream of a media file.

    Args:
        video_path: Path to the media file to inspect.

    Returns:
        `{"has_audio": False}` if the file has no audio stream. Otherwise a dict with the keys
        "audio.channels", "audio.codec", "audio.bit_rate", "audio.sample_rate",
        "audio.bit_depth", "audio.channel_layout" and "has_audio" (True).
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    try:
        # Getting audio stream information
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            audio_info = {}
            audio_info["audio.channels"] = audio_stream.channels
            audio_info["audio.codec"] = audio_stream.codec.canonical_name
            # In an ideal lossless case : bit depth x sample rate x channels = bit rate.
            # In an actual compressed case, the bit rate is set according to the compression level : the lower the bit rate, the more compression is applied.
            audio_info["audio.bit_rate"] = audio_stream.bit_rate
            audio_info["audio.sample_rate"] = audio_stream.sample_rate  # Number of samples per second
            # In an ideal lossless case : fixed number of bits per sample.
            # In an actual compressed case : variable number of bits per sample (often reduced to match a given depth rate).
            audio_info["audio.bit_depth"] = audio_stream.format.bits
            audio_info["audio.channel_layout"] = audio_stream.layout.name
            audio_info["has_audio"] = True
            return audio_info
    finally:
        # Reset logging level on every exit path, including when opening or reading the file fails
        av.logging.restore_default_callback()
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
from lerobot.datasets.utils import load_audio_from_path, load_image_as_numpy
|
||||
|
||||
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
|
||||
|
||||
@@ -245,6 +245,20 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
|
||||
return images
|
||||
|
||||
|
||||
def sample_audio_from_path(audio_path: str) -> np.ndarray:
    """Loads the audio recording stored at `audio_path` and returns the entries picked by `sample_indices`."""
    recording = load_audio_from_path(audio_path)
    return recording[sample_indices(len(recording))]
|
||||
|
||||
|
||||
def sample_audio_from_data(data: np.ndarray) -> np.ndarray:
    """Returns the entries of an in-memory audio recording picked by `sample_indices`."""
    num_entries = len(data)
    kept = sample_indices(num_entries)
    return data[kept]
|
||||
|
||||
|
||||
def _reshape_stats_by_axis(
|
||||
stats: dict[str, np.ndarray],
|
||||
axis: int | tuple[int, ...] | None,
|
||||
@@ -512,6 +526,13 @@ def compute_episode_stats(
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
elif features[key]["dtype"] == "audio":
|
||||
try:
|
||||
ep_ft_array = sample_audio_from_path(data[0])
|
||||
except TypeError: # Should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
ep_ft_array = sample_audio_from_data(data)
|
||||
axes_to_reduce = 0
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_PREFIX, REWARD
|
||||
from lerobot.utils.constants import ACTION, OBS_PREFIX, REWARD
|
||||
|
||||
IMAGENET_STATS = {
|
||||
"mean": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)
|
||||
@@ -59,12 +59,7 @@ def resolve_delta_timestamps(
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.reward_delta_indices]
|
||||
if key == ACTION and cfg.action_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
|
||||
|
||||
# Check for image-specific delta indices first (e.g., for video encoding)
|
||||
if key.startswith(OBS_IMAGES) and cfg.image_observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.image_observation_delta_indices]
|
||||
# Fall back to generic observation delta indices for all observations
|
||||
elif key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
if key.startswith(OBS_PREFIX) and cfg.observation_delta_indices is not None:
|
||||
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
|
||||
|
||||
if len(delta_timestamps) == 0:
|
||||
|
||||
@@ -33,12 +33,16 @@ import torch.utils
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
|
||||
from lerobot.datasets.audio_utils import decode_audio, encode_audio, get_audio_info
|
||||
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
|
||||
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||
DEFAULT_RAW_AUDIO_PATH,
|
||||
INFO_PATH,
|
||||
_validate_feature_names,
|
||||
check_delta_timestamps,
|
||||
@@ -68,13 +72,15 @@ from lerobot.datasets.utils import (
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
concatenate_media_files,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_media_duration_in_s,
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.microphones import Microphone
|
||||
from lerobot.microphones.utils import async_microphones_start_recording
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
@@ -214,6 +220,19 @@ class LeRobotDatasetMetadata:
|
||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||
return Path(fpath)
|
||||
|
||||
def get_audio_file_path(self, ep_index: int, audio_key: str) -> Path:
    """Returns the path of the audio file that stores `audio_key` for episode `ep_index`.

    The path is built by formatting `self.audio_path` with the chunk and file indices
    recorded for the episode. Episodes are loaded from `self.root` when not loaded yet.

    Raises:
        IndexError: If `ep_index` is greater than or equal to the number of episodes.
    """
    # Lazily load the episodes the first time they are needed
    if self.episodes is None:
        self.episodes = load_episodes(self.root)

    if not ep_index < len(self.episodes):
        raise IndexError(
            f"Episode index {ep_index} out of range. Episodes: {len(self.episodes) if self.episodes else 0}"
        )

    episode = self.episodes[ep_index]
    location = {
        "chunk_index": episode[f"audio/{audio_key}/chunk_index"],
        "file_index": episode[f"audio/{audio_key}/file_index"],
    }
    return Path(self.audio_path.format(audio_key=audio_key, **location))
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
@@ -224,6 +243,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["video_path"]
|
||||
|
||||
@property
|
||||
def audio_path(self) -> str | None:
|
||||
"""Formattable string for the audio files."""
|
||||
return self.info["audio_path"]
|
||||
|
||||
@property
|
||||
def robot_type(self) -> str | None:
|
||||
"""Robot type used in recording this dataset."""
|
||||
@@ -254,6 +278,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||
|
||||
@property
def audio_keys(self) -> list[str]:
    """Keys of the features whose dtype is "audio"."""
    keys = []
    for name, feature in self.features.items():
        if feature["dtype"] == "audio":
            keys.append(name)
    return keys
|
||||
|
||||
@property
|
||||
def names(self) -> dict[str, list | dict]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
@@ -294,6 +323,11 @@ class LeRobotDatasetMetadata:
|
||||
"""Max size of video file in mega bytes."""
|
||||
return self.info["video_files_size_in_mb"]
|
||||
|
||||
@property
def audio_files_size_in_mb(self) -> int:
    """Maximum size of an audio file, in megabytes, as stored in the dataset info."""
    max_size_in_mb = self.info["audio_files_size_in_mb"]
    return max_size_in_mb
|
||||
|
||||
def get_task_index(self, task: str) -> int | None:
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
@@ -435,11 +469,27 @@ class LeRobotDatasetMetadata:
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
def update_audio_info(self, audio_key: str | None = None) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode audio, implicitly assuming that all audio have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
if audio_key is not None and audio_key not in self.audio_keys:
|
||||
raise ValueError(f"Audio key {audio_key} not found in dataset")
|
||||
|
||||
audio_keys = [audio_key] if audio_key is not None else self.audio_keys
|
||||
for key in audio_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
audio_path = self.root / self.audio_path.format(audio_key=key, chunk_index=0, file_index=0)
|
||||
self.info["features"][key]["info"] = get_audio_info(audio_path)
|
||||
self.info["features"][key]["info"]["start_time_s"] = DEFAULT_INITIAL_AUDIO_BUFFER_DURATION
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> None:
|
||||
"""Update chunk and file size settings after dataset creation.
|
||||
|
||||
@@ -451,6 +501,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: Maximum number of files per chunk directory. If None, keeps current value.
|
||||
data_files_size_in_mb: Maximum size for data parquet files in MB. If None, keeps current value.
|
||||
video_files_size_in_mb: Maximum size for video files in MB. If None, keeps current value.
|
||||
audio_files_size_in_mb: Maximum size for audio files in MB. If None, keeps current value.
|
||||
"""
|
||||
if chunks_size is not None:
|
||||
if chunks_size <= 0:
|
||||
@@ -467,6 +518,11 @@ class LeRobotDatasetMetadata:
|
||||
raise ValueError(f"video_files_size_in_mb must be positive, got {video_files_size_in_mb}")
|
||||
self.info["video_files_size_in_mb"] = video_files_size_in_mb
|
||||
|
||||
if audio_files_size_in_mb is not None:
|
||||
if audio_files_size_in_mb <= 0:
|
||||
raise ValueError(f"audio_files_size_in_mb must be positive, got {audio_files_size_in_mb}")
|
||||
self.info["audio_files_size_in_mb"] = audio_files_size_in_mb
|
||||
|
||||
# Update the info file on disk
|
||||
write_info(self.info, self.root)
|
||||
|
||||
@@ -474,12 +530,13 @@ class LeRobotDatasetMetadata:
|
||||
"""Get current chunk and file size settings.
|
||||
|
||||
Returns:
|
||||
Dict containing chunks_size, data_files_size_in_mb, and video_files_size_in_mb.
|
||||
Dict containing chunks_size, data_files_size_in_mb, video_files_size_in_mb, and audio_files_size_in_mb.
|
||||
"""
|
||||
return {
|
||||
"chunks_size": self.chunks_size,
|
||||
"data_files_size_in_mb": self.data_files_size_in_mb,
|
||||
"video_files_size_in_mb": self.video_files_size_in_mb,
|
||||
"audio_files_size_in_mb": self.audio_files_size_in_mb,
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
@@ -506,6 +563,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
obj = cls.__new__(cls)
|
||||
@@ -529,6 +587,7 @@ class LeRobotDatasetMetadata:
|
||||
chunks_size,
|
||||
data_files_size_in_mb,
|
||||
video_files_size_in_mb,
|
||||
audio_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
@@ -564,7 +623,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
download_videos: bool = True,
|
||||
download_audio: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
):
|
||||
@@ -598,6 +659,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
task-conditioned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- videos (optional) from which frames are loaded to be synchronous with data from parquet files.
|
||||
- audio (optional) from which audio is loaded to be synchronous with data from parquet files.
|
||||
|
||||
A typical LeRobotDataset looks like this from its root path:
|
||||
.
|
||||
@@ -623,19 +685,37 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.parquet
|
||||
└── videos
|
||||
├── observation.images.laptop
|
||||
├── videos
|
||||
│ ├── observation.images.laptop
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ ├── observation.images.phone
|
||||
│ │ ├── chunk-000
|
||||
│ │ │ ├── file-000.mp4
|
||||
│ │ │ ├── file-001.mp4
|
||||
│ │ │ └── ...
|
||||
│ │ ├── chunk-001
|
||||
│ │ │ └── ...
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
└── audio
|
||||
├── observation.audio.laptop
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── observation.images.phone
|
||||
├── observation.audio.phone
|
||||
│ ├── chunk-000
|
||||
│ │ ├── file-000.mp4
|
||||
│ │ ├── file-001.mp4
|
||||
│ │ ├── file-000.m4a
|
||||
│ │ ├── file-001.m4a
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ └── ...
|
||||
@@ -675,8 +755,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
download_audio (bool, optional): Flag to download the audio. Defaults to True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available in the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
audio_backend (str | None, optional): Audio backend to use for decoding audio. Defaults to 'torchcodec'.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
@@ -694,6 +776,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.audio_backend = (
|
||||
audio_backend if audio_backend else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
@@ -766,6 +851,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
push_audio: bool = True,
|
||||
private: bool = False,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
upload_large_folder: bool = False,
|
||||
@@ -774,6 +860,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not push_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.create_repo(
|
||||
@@ -828,7 +916,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
def download(self, download_videos: bool = True) -> None:
|
||||
def download(self, download_videos: bool = True, download_audio: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
@@ -836,8 +924,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
files = None
|
||||
ignore_patterns = []
|
||||
if not download_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
if not download_audio:
|
||||
ignore_patterns.append("audio/")
|
||||
if self.episodes is not None:
|
||||
files = self.get_episodes_file_paths()
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
@@ -852,6 +944,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += video_files
|
||||
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
audio_files = [
|
||||
str(self.meta.get_compressed_audio_file_path(ep_idx, audio_key))
|
||||
for audio_key in self.meta.audio_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
fpaths += audio_files
|
||||
|
||||
# episodes are stored in the same files, so we return unique paths only
|
||||
fpaths = list(set(fpaths))
|
||||
return fpaths
|
||||
@@ -864,7 +965,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return hf_dataset
|
||||
|
||||
def _check_cached_episodes_sufficient(self) -> bool:
|
||||
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||
"""Check if the cached dataset contains all requested episodes and their video and audio files."""
|
||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||
return False
|
||||
|
||||
@@ -892,6 +993,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if not video_path.exists():
|
||||
return False
|
||||
|
||||
# Check if all required audio files exist
|
||||
if len(self.meta.audio_keys) > 0:
|
||||
for ep_idx in requested_episodes:
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
if not audio_path.exists():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def create_hf_dataset(self) -> datasets.Dataset:
|
||||
@@ -970,7 +1079,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.meta.video_keys:
|
||||
for key in self.meta.video_keys + self.meta.audio_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
if self._absolute_to_relative_idx is not None:
|
||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||
@@ -985,7 +1094,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
Query dataset for indices across keys, skipping video keys and audio keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
@@ -997,7 +1106,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
if key in self.meta.video_keys or key in self.meta.audio_keys:
|
||||
continue
|
||||
# Map absolute indices to relative indices if needed
|
||||
relative_indices = (
|
||||
@@ -1032,6 +1141,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return item
|
||||
|
||||
# TODO(CarolinePascal): add variable query durations
|
||||
def _query_audio(
|
||||
self, query_timestamps: dict[str, list[float]], query_duration: float, ep_idx: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
item = {}
|
||||
for audio_key, query_ts in query_timestamps.items():
|
||||
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
|
||||
# Thus we load the start timestamp of the episode on this mp4 and,
|
||||
# shift the query timestamp accordingly.
|
||||
from_timestamp = ep[f"audio/{audio_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
|
||||
audio_path = self.root / self.meta.get_audio_file_path(ep_idx, audio_key)
|
||||
start_time_s = self.meta.features[audio_key]["info"].get("start_time_s", 0.0)
|
||||
audio_chunk = decode_audio(
|
||||
audio_path, shifted_query_ts, query_duration, start_time_s, self.audio_backend
|
||||
)
|
||||
item[audio_key] = audio_chunk.squeeze(0)
|
||||
|
||||
return item
|
||||
|
||||
def _ensure_hf_dataset_loaded(self):
|
||||
"""Lazy load the HF dataset only when needed for reading."""
|
||||
if self._lazy_loading or self.hf_dataset is None:
|
||||
@@ -1061,11 +1192,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0 or len(self.meta.audio_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
audio_chunks = self._query_audio(query_timestamps, DEFAULT_AUDIO_CHUNK_DURATION, ep_idx)
|
||||
item = {**item, **video_frames, **audio_chunks}
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
@@ -1113,6 +1245,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
return self.root / fpath
|
||||
|
||||
def _get_raw_audio_file_path(self, episode_index: int, audio_key: str) -> Path:
    """Return the path, under the dataset root, of the raw audio file of an episode.

    The relative location is obtained by filling `DEFAULT_RAW_AUDIO_PATH` with the
    audio key and the episode index.
    """
    relative_path = DEFAULT_RAW_AUDIO_PATH.format(episode_index=episode_index, audio_key=audio_key)
    return self.root / relative_path
|
||||
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
@@ -1165,11 +1301,43 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
elif self.features[key]["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
else: # Otherwise, only the audio file path is stored in the episode buffer
|
||||
if frame_index == 0:
|
||||
audio_path = self._get_raw_audio_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], audio_key=key
|
||||
)
|
||||
self.episode_buffer[key].append(str(audio_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def add_microphone_recording(self, microphone_key: str, microphone: Microphone) -> None:
    """
    Starts recording audio data provided by the microphone and directly writes it in a .wav file.

    The output file is the raw audio path of the episode numbered `self.num_episodes`
    for the feature key "observation.audio.<microphone_key>".
    """
    feature_key = "observation.audio." + microphone_key
    destination = self._get_raw_audio_file_path(self.num_episodes, feature_key)
    microphone.start_recording(output_file=destination)
|
||||
|
||||
def add_microphones_recordings(self, microphones: dict[str, Microphone]) -> None:
    """
    Starts recording audio data provided by multiple microphones and directly writes it in appropriate .wav files.

    One raw audio path is built per microphone key (in the iteration order of
    `microphones`) for the episode numbered `self.num_episodes`, then all recordings
    are started through `async_microphones_start_recording`.
    """
    recording_paths = [
        self._get_raw_audio_file_path(self.num_episodes, "observation.audio." + mic_name)
        for mic_name in microphones
    ]

    async_microphones_start_recording(microphones, recording_paths)
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
@@ -1213,6 +1381,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
elif ft["dtype"] == "audio":
|
||||
if (
|
||||
self.meta.robot_type == "lekiwi"
|
||||
): # Raw data storage should only be triggered for LeKiwi robot, for which audio is stored chunk by chunk in a visual frame-like manner
|
||||
episode_buffer[key] = np.concatenate(episode_buffer[key], axis=0)
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
@@ -1221,9 +1395,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
has_audio_keys = len(self.meta.audio_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1260,21 +1435,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# TODO(Caroline): add parallel encoding for audio as well
|
||||
for audio_key in self.meta.audio_keys:
|
||||
ep_metadata.update(self._save_episode_audio(audio_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
|
||||
if has_video_keys and use_batched_encoding:
|
||||
if (has_video_keys or has_audio_keys) and use_batched_encoding:
|
||||
# Check if we should trigger batch encoding
|
||||
self.episodes_since_last_encoding += 1
|
||||
if self.episodes_since_last_encoding == self.batch_encoding_size:
|
||||
start_ep = self.num_episodes - self.batch_encoding_size
|
||||
end_ep = self.num_episodes
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_video_keys:
|
||||
self._batch_save_episode_video(start_ep, end_ep)
|
||||
if has_audio_keys:
|
||||
self._batch_save_episode_audio(start_ep, end_ep)
|
||||
self.episodes_since_last_encoding = 0
|
||||
|
||||
if not episode_data:
|
||||
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
|
||||
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
|
||||
self.clear_episode_buffer(
|
||||
delete_images=len(self.meta.image_keys) > 0, delete_audio=len(self.meta.audio_keys) > 0
|
||||
)
|
||||
|
||||
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
|
||||
"""
|
||||
@@ -1325,7 +1509,70 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
# Save the current episode's audio metadata to the dataframe
|
||||
audio_ep_metadata = {}
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
|
||||
audio_ep_metadata.pop("episode_index")
|
||||
audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
|
||||
dtype_backend="pyarrow"
|
||||
) # allows NaN values along with integers
|
||||
|
||||
episode_df = episode_df.combine_first(video_ep_df)
|
||||
episode_df = episode_df.combine_first(audio_ep_df)
|
||||
episode_df.to_parquet(episode_df_path)
|
||||
self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
def _batch_save_episode_audio(self, start_episode: int, end_episode: int | None = None) -> None:
    """
    Batch save audio for multiple episodes.

    For every episode in the range, each audio key is encoded through `_save_episode_audio`
    and the returned metadata is merged into the episodes parquet file referenced by that
    episode's "data/chunk_index" and "data/file_index" entries.

    Args:
        start_episode: Starting episode index (inclusive)
        end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode to the current episode.
    """
    if end_episode is None:
        end_episode = self.num_episodes

    # NOTE(review): the logged count is `batch_encoding_size`, which may differ from
    # `end_episode - start_episode` when an explicit range is passed — confirm intent.
    logging.info(
        f"Batch encoding {self.batch_encoding_size} audio for episodes {start_episode} to {end_episode - 1}"
    )

    # Open the episodes parquet file that holds the first episode of the batch
    chunk_idx = self.meta.episodes[start_episode]["data/chunk_index"]
    file_idx = self.meta.episodes[start_episode]["data/file_index"]
    episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
    episode_df = pd.read_parquet(episode_df_path)

    for ep_idx in range(start_episode, end_episode):
        logging.info(f"Encoding audio for episode {ep_idx}")

        if (
            self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
            or self.meta.episodes[ep_idx]["data/file_index"] != file_idx
        ):
            # The current episode is in a new chunk or file.
            # Save previous episode dataframe and update the Hugging Face dataset by reloading it.
            episode_df.to_parquet(episode_df_path)
            self.meta.episodes = load_episodes(self.root)

            # Load new episode dataframe
            chunk_idx = self.meta.episodes[ep_idx]["data/chunk_index"]
            file_idx = self.meta.episodes[ep_idx]["data/file_index"]
            episode_df_path = self.root / DEFAULT_EPISODES_PATH.format(
                chunk_index=chunk_idx, file_index=file_idx
            )
            episode_df = pd.read_parquet(episode_df_path)

        # Save the current episode's audio metadata to the dataframe
        audio_ep_metadata = {}
        for audio_key in self.meta.audio_keys:
            audio_ep_metadata.update(self._save_episode_audio(audio_key, ep_idx))
        # The episode index is used as the dataframe index below, not as a column
        audio_ep_metadata.pop("episode_index")
        audio_ep_df = pd.DataFrame(audio_ep_metadata, index=[ep_idx]).convert_dtypes(
            dtype_backend="pyarrow"
        )  # allows NaN values along with integers

        # Merge the audio columns into the episode row, write the file and reload the
        # in-memory episodes metadata from disk
        episode_df = episode_df.combine_first(audio_ep_df)
        episode_df.to_parquet(episode_df_path)
        self.meta.episodes = load_episodes(self.root)
|
||||
|
||||
@@ -1436,7 +1683,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_path = temp_path
|
||||
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
if (
|
||||
episode_index == 0
|
||||
@@ -1482,7 +1729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
latest_duration_in_s = 0.0
|
||||
else:
|
||||
# Update latest video file
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
[latest_path, ep_path],
|
||||
latest_path,
|
||||
)
|
||||
@@ -1504,7 +1751,79 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
}
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
def _save_episode_audio(self, audio_key: str, episode_index: int) -> dict:
    """Encode the audio of one episode and store it in the dataset's audio files.

    The episode audio is first encoded into a temporary file, then either moved to a
    new chunk/file location or concatenated to the latest audio file of `audio_key`,
    depending on the size limit `self.meta.audio_files_size_in_mb`.

    Args:
        audio_key: Audio feature key the episode audio belongs to.
        episode_index: Index of the episode to encode.

    Returns:
        A dict with "episode_index" and, for `audio_key`, the "chunk_index",
        "file_index", "from_timestamp" and "to_timestamp" entries (prefixed with
        "audio/<audio_key>/") locating the episode inside its audio file.
    """
    # Encode episode audio into a temporary audio file
    ep_path = self._encode_temporary_episode_audio(audio_key, episode_index)
    ep_size_in_mb = get_file_size_in_mb(ep_path)
    ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")

    if (
        episode_index == 0
        or self.meta.latest_episode is None
        or f"audio/{audio_key}/chunk_index" not in self.meta.latest_episode
    ):
        # Initialize indices for a new dataset made of the first episode data
        chunk_idx, file_idx = 0, 0
        if self.meta.episodes is not None and len(self.meta.episodes) > 0:
            # It means we are resuming recording, so we need to load the latest episode
            # Update the indices to avoid overwriting the latest episode
            old_chunk_idx = self.meta.episodes[-1][f"audio/{audio_key}/chunk_index"]
            old_file_idx = self.meta.episodes[-1][f"audio/{audio_key}/file_index"]
            chunk_idx, file_idx = update_chunk_file_indices(
                old_chunk_idx, old_file_idx, self.meta.chunks_size
            )
        # The episode starts a fresh audio file, so it begins at timestamp 0
        latest_duration_in_s = 0.0
        new_path = self.root / self.meta.audio_path.format(
            audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
        )
        new_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(ep_path), str(new_path))
    else:
        # Retrieve information from the latest updated audio file using latest_episode
        latest_ep = self.meta.latest_episode
        chunk_idx = latest_ep[f"audio/{audio_key}/chunk_index"][0]
        file_idx = latest_ep[f"audio/{audio_key}/file_index"][0]

        latest_path = self.root / self.meta.audio_path.format(
            audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
        )
        latest_size_in_mb = get_file_size_in_mb(latest_path)
        # The new episode starts where the previous one ended in the latest file
        latest_duration_in_s = latest_ep[f"audio/{audio_key}/to_timestamp"][0]

        if latest_size_in_mb + ep_size_in_mb >= self.meta.audio_files_size_in_mb:
            # Move temporary episode audio to a new audio file in the dataset
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
            new_path = self.root / self.meta.audio_path.format(
                audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx
            )
            new_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.move(str(ep_path), str(new_path))
            latest_duration_in_s = 0.0
        else:
            # Update latest audio file
            concatenate_media_files(
                [latest_path, ep_path],
                latest_path,
            )

    # Remove temporary directory
    shutil.rmtree(str(ep_path.parent))

    # Update audio info (only needed when first episode is encoded since it reads from episode 0)
    if episode_index == 0:
        self.meta.update_audio_info(audio_key)
        write_info(self.meta.info, self.meta.root)  # ensure audio info always written properly

    metadata = {
        "episode_index": episode_index,
        f"audio/{audio_key}/chunk_index": chunk_idx,
        f"audio/{audio_key}/file_index": file_idx,
        f"audio/{audio_key}/from_timestamp": latest_duration_in_s,
        f"audio/{audio_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
    }
    return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True, delete_audio: bool = True) -> None:
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1518,6 +1837,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Clean up audio files for the current episode buffer
|
||||
if delete_audio:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for audio_key in self.meta.audio_keys:
|
||||
audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
|
||||
if audio_file.is_file():
|
||||
audio_file.unlink()
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
@@ -1554,6 +1883,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
def _encode_temporary_episode_audio(self, audio_key: str, episode_index: int) -> Path:
    """
    Use ffmpeg to convert raw audio files into m4a audio files.
    Note: `encode_audio` is a blocking call. Making it asynchronous shouldn't speedup encoding,
    since audio encoding with ffmpeg is already using multithreading.

    Args:
        audio_key: Audio feature key whose raw recording is encoded.
        episode_index: Index of the episode whose raw recording is encoded.

    Returns:
        Path of the encoded .m4a file, located in a fresh temporary directory created
        inside the dataset root.
    """
    # A dedicated temporary directory is created under the dataset root for the output file
    temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{audio_key}_{episode_index:03d}.m4a"
    raw_audio_file = self._get_raw_audio_file_path(episode_index, audio_key)
    encode_audio(raw_audio_file, temp_path, overwrite=True)
    # The raw recording is deleted once it has been encoded
    raw_audio_file.unlink()
    return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
@@ -1567,6 +1908,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
) -> "LeRobotDataset":
|
||||
@@ -1611,6 +1953,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
obj.audio_backend = (
|
||||
audio_backend if audio_backend is not None else "torchcodec"
|
||||
) # Waiting for torchcodec release #TODO(CarolinePascal)
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1631,6 +1976,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerances_s: dict | None = None,
|
||||
download_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
audio_backend: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
@@ -1648,6 +1994,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
tolerance_s=self.tolerances_s[repo_id],
|
||||
download_videos=download_videos,
|
||||
video_backend=video_backend,
|
||||
audio_backend=audio_backend,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
|
||||
@@ -36,6 +36,7 @@ from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from huggingface_hub.errors import RevisionNotFoundError
|
||||
from PIL import Image as PILImage
|
||||
from soundfile import read
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
@@ -50,6 +51,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
@@ -57,13 +59,19 @@ STATS_PATH = "meta/stats.json"
|
||||
EPISODES_DIR = "meta/episodes"
|
||||
DATA_DIR = "data"
|
||||
VIDEO_DIR = "videos"
|
||||
AUDIO_DIR = "audio"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_AUDIO_PATH = AUDIO_DIR + "/{audio_key}/" + CHUNK_FILE_PATTERN + ".m4a"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
DEFAULT_RAW_AUDIO_PATH = "raw_audio/{audio_key}/episode_{episode_index:06d}.wav"
|
||||
|
||||
DEFAULT_AUDIO_CHUNK_DURATION = 0.5 # seconds
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION = 1.0 # seconds
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
@@ -408,6 +416,16 @@ def load_image_as_numpy(
|
||||
return img_array
|
||||
|
||||
|
||||
def load_audio_from_path(fpath: str | Path) -> np.ndarray:
    """Read an audio file as float32 samples and make sure a channel axis is present.

    Args:
        fpath: Path of the audio file, forwarded to `soundfile.read`.

    Returns:
        The audio samples; one-dimensional (mono) data gets a trailing axis of size 1.
    """
    samples, _sample_rate = read(fpath, dtype="float32")

    # Mono data comes back without a channel dimension: add it as the second axis
    return np.expand_dims(samples, axis=1) if samples.ndim == 1 else samples
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
|
||||
"""Convert a batch from a Hugging Face dataset to torch tensors.
|
||||
|
||||
@@ -576,7 +594,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
"""
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
if ft["dtype"] == "video" or ft["dtype"] == "audio":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
@@ -639,7 +657,12 @@ def hw_to_dataset_features(
|
||||
for key, ftype in hw_features.items()
|
||||
if ftype is float or (isinstance(ftype, PolicyFeature) and ftype.type != FeatureType.VISUAL)
|
||||
}
|
||||
cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)}
|
||||
cam_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 3
|
||||
}
|
||||
mic_fts = {
|
||||
key: shape for key, shape in hw_features.items() if isinstance(shape, tuple) and len(shape) == 2
|
||||
}
|
||||
|
||||
if joint_fts and prefix == ACTION:
|
||||
features[prefix] = {
|
||||
@@ -662,6 +685,14 @@ def hw_to_dataset_features(
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
for key, parameters in mic_fts.items():
|
||||
features[f"{prefix}.audio.{key}"] = {
|
||||
"dtype": "audio",
|
||||
"shape": (len(parameters[1]),),
|
||||
"names": ["channels"],
|
||||
"info": {"sample_rate": parameters[0]},
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
|
||||
@@ -691,6 +722,8 @@ def build_dataset_frame(
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
elif ft["dtype"] == "audio":
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.audio.")]
|
||||
|
||||
return frame
|
||||
|
||||
@@ -724,6 +757,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif ft["dtype"] == "audio":
|
||||
type = FeatureType.AUDIO
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
@@ -802,6 +839,7 @@ def create_empty_dataset_info(
|
||||
chunks_size: int | None = None,
|
||||
data_files_size_in_mb: int | None = None,
|
||||
video_files_size_in_mb: int | None = None,
|
||||
audio_files_size_in_mb: int | None = None,
|
||||
) -> dict:
|
||||
"""Create a template dictionary for a new dataset's `info.json`.
|
||||
|
||||
@@ -811,6 +849,10 @@ def create_empty_dataset_info(
|
||||
features (dict): The LeRobot features dictionary for the dataset.
|
||||
use_videos (bool): Whether the dataset will store videos.
|
||||
robot_type (str | None): The type of robot used, if any.
|
||||
chunks_size (int | None): The maximum number of files per chunk directory.
|
||||
data_files_size_in_mb (int | None): The maximum size for data files in MB.
|
||||
video_files_size_in_mb (int | None): The maximum size for video files in MB.
|
||||
audio_files_size_in_mb (int | None): The maximum size for audio files in MB.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with the initial dataset metadata.
|
||||
@@ -824,10 +866,12 @@ def create_empty_dataset_info(
|
||||
"chunks_size": chunks_size or DEFAULT_CHUNK_SIZE,
|
||||
"data_files_size_in_mb": data_files_size_in_mb or DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
"video_files_size_in_mb": video_files_size_in_mb or DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
"audio_files_size_in_mb": audio_files_size_in_mb or DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": DEFAULT_DATA_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"audio_path": DEFAULT_AUDIO_PATH,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -1051,6 +1095,8 @@ def validate_feature_dtype_and_shape(
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "audio":
|
||||
return validate_feature_audio(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
@@ -1117,6 +1163,23 @@ def validate_feature_image_or_video(
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_audio(name: str, expected_shape: list[str], value: np.ndarray):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c = expected_shape
|
||||
if (len(actual_shape) != 2 and len(actual_shape) != 1) or actual_shape[-1] != c[
|
||||
-1
|
||||
]: # The number of frames might be different
|
||||
error_message += (
|
||||
f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{c}'.\n"
|
||||
)
|
||||
else:
|
||||
error_message += f"The feature '{name}' is expected to be of type 'np.ndarray', but type '{type(value)}' provided instead.\n"
|
||||
|
||||
return error_message
|
||||
|
||||
|
||||
def validate_feature_string(name: str, value: str) -> str:
|
||||
"""Validate a feature that is expected to be a string.
|
||||
|
||||
|
||||
@@ -59,6 +59,8 @@ from requests import HTTPError
|
||||
from lerobot.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -79,7 +81,7 @@ from lerobot.datasets.utils import (
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
|
||||
from lerobot.datasets.video_utils import concatenate_media_files, get_media_duration_in_s
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
@@ -311,12 +313,12 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
for ep_path in tqdm.tqdm(ep_paths, desc=f"convert videos of {video_key}"):
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="video")
|
||||
|
||||
# Check if adding this episode would exceed the limit
|
||||
if size_in_mb + ep_size_in_mb >= video_file_size_in_mb and len(paths_to_cat) > 0:
|
||||
# Size limit would be exceeded, save current accumulation WITHOUT this episode
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -352,7 +354,7 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
|
||||
# Write remaining videos if any
|
||||
if paths_to_cat:
|
||||
concatenate_video_files(
|
||||
concatenate_media_files(
|
||||
paths_to_cat,
|
||||
new_root
|
||||
/ DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=chunk_idx, file_index=file_idx),
|
||||
@@ -367,8 +369,124 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key: str, video_f
|
||||
return episodes_metadata
|
||||
|
||||
|
||||
def get_audio_keys(root) -> list[str]:
    """Return the feature keys declared with dtype ``"audio"`` in the dataset info.

    Args:
        root: Dataset root handed to ``load_info``.

    Returns:
        The keys of ``info["features"]`` whose ``"dtype"`` is ``"audio"``, in the
        order they appear in the features mapping.
    """
    features = load_info(root)["features"]
    return [key for key, ft in features.items() if ft["dtype"] == "audio"]
|
||||
|
||||
|
||||
def convert_audios(root: Path, new_root: Path, audio_file_size_in_mb: int) -> list[dict] | None:
    """Convert the audio files of every microphone and merge their per-episode metadata.

    Each audio key is converted independently with ``convert_audios_of_microphone``;
    the resulting per-microphone episode dictionaries are then merged episode by
    episode into a single dictionary.

    Args:
        root: Root of the dataset being converted.
        new_root: Root of the converted dataset.
        audio_file_size_in_mb: Size limit forwarded to the per-microphone conversion.

    Returns:
        One merged metadata dictionary per episode, or ``None`` when the dataset
        declares no audio feature.

    Raises:
        ValueError: If the microphones do not all report the same number of
            episodes, or if the episode indices disagree for a given position.
    """
    logging.info(f"Converting audios from {root} to {new_root}")

    # Sorted so that microphones are always processed in a deterministic order.
    audio_keys = sorted(get_audio_keys(root))
    if not audio_keys:
        return None

    eps_metadata_per_mic = [
        convert_audios_of_microphone(root, new_root, microphone, audio_file_size_in_mb)
        for microphone in audio_keys
    ]

    num_eps_per_mic = [len(eps_mic_map) for eps_mic_map in eps_metadata_per_mic]
    if len(set(num_eps_per_mic)) != 1:
        raise ValueError(f"All microphones dont have same number of episodes ({num_eps_per_mic}).")

    episodes_metadata = []
    for ep_idx in tqdm.tqdm(range(num_eps_per_mic[0]), desc="convert audios"):
        # Sanity check: every microphone must agree with the positional index.
        ep_ids = [eps_mic_map[ep_idx]["episode_index"] for eps_mic_map in eps_metadata_per_mic]
        ep_ids += [ep_idx]
        if len(set(ep_ids)) != 1:
            raise ValueError(f"All episode indices need to match ({ep_ids}).")

        # Keys are namespaced per audio key, so merging only overlaps on "episode_index".
        ep_dict = {}
        for eps_mic_map in eps_metadata_per_mic:
            ep_dict.update(eps_mic_map[ep_idx])
        episodes_metadata.append(ep_dict)

    return episodes_metadata
|
||||
|
||||
|
||||
def convert_audios_of_microphone(
    root: Path, new_root: Path, audio_key: str, audio_file_size_in_mb: int
) -> list[dict]:
    """Concatenate the per-episode ``.m4a`` files of one microphone into larger files.

    Episodes are accumulated in sorted path order. Before an episode is added, if
    the accumulated size plus that episode would reach ``audio_file_size_in_mb``
    (and something is already accumulated), the accumulation is written out and a
    new chunk/file slot is started with the current episode.

    Args:
        root: Root of the dataset being converted; files are read from ``root / "audio"``.
        new_root: Root of the converted dataset; outputs follow ``DEFAULT_AUDIO_PATH``.
        audio_key: Feature key of the microphone to convert.
        audio_file_size_in_mb: Size threshold for a concatenated file.

    Returns:
        One dictionary per episode holding ``"episode_index"`` plus the chunk index,
        file index and from/to timestamps (offsets inside the concatenated file)
        under ``audio/<audio_key>/...`` keys.
    """
    # Access old paths to m4a
    ep_paths = sorted((root / "audio").glob(f"*/{audio_key}/*.m4a"))

    chunk_idx = 0
    file_idx = 0
    size_in_mb = 0
    duration_in_s = 0.0
    paths_to_cat = []
    episodes_metadata = []

    def write_accumulated_file() -> None:
        # Must only be called with a non-empty accumulation: the `-len(...)` slice
        # below would otherwise select the whole list.
        concatenate_media_files(
            paths_to_cat,
            new_root
            / DEFAULT_AUDIO_PATH.format(audio_key=audio_key, chunk_index=chunk_idx, file_index=file_idx),
        )
        # Stamp the episodes stored in the file we just saved with its indices.
        for saved_ep_metadata in episodes_metadata[-len(paths_to_cat) :]:
            saved_ep_metadata[f"audio/{audio_key}/chunk_index"] = chunk_idx
            saved_ep_metadata[f"audio/{audio_key}/file_index"] = file_idx

    for ep_idx, ep_path in enumerate(tqdm.tqdm(ep_paths, desc=f"convert audios of {audio_key}")):
        ep_size_in_mb = get_file_size_in_mb(ep_path)
        ep_duration_in_s = get_media_duration_in_s(ep_path, media_type="audio")

        # Check if adding this episode would exceed the limit
        if size_in_mb + ep_size_in_mb >= audio_file_size_in_mb and len(paths_to_cat) > 0:
            # Size limit would be exceeded, save current accumulation WITHOUT this episode
            write_accumulated_file()

            # Move to next file and start fresh with current episode
            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE)
            size_in_mb = 0
            duration_in_s = 0.0
            paths_to_cat = []

        # Add current episode metadata
        episodes_metadata.append(
            {
                "episode_index": ep_idx,
                f"audio/{audio_key}/chunk_index": chunk_idx,  # Will be updated when file is saved
                f"audio/{audio_key}/file_index": file_idx,  # Will be updated when file is saved
                f"audio/{audio_key}/from_timestamp": duration_in_s,
                f"audio/{audio_key}/to_timestamp": duration_in_s + ep_duration_in_s,
            }
        )

        # Add current episode to accumulation
        paths_to_cat.append(ep_path)
        size_in_mb += ep_size_in_mb
        duration_in_s += ep_duration_in_s

    # Write remaining audio files if any
    if paths_to_cat:
        write_accumulated_file()

    return episodes_metadata
|
||||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None, episodes_audios=None
|
||||
):
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
@@ -392,16 +510,30 @@ def generate_episode_metadata_dict(
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if episodes_audios is None:
|
||||
ep_audio = {}
|
||||
else:
|
||||
ep_audio = episodes_audios[i]
|
||||
ep_ids_set.add(ep_audio["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict = {
|
||||
**ep_metadata,
|
||||
**ep_video,
|
||||
**ep_audio,
|
||||
**ep_legacy_metadata,
|
||||
**flatten_dict({"stats": ep_stats}),
|
||||
}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
ep_dict["meta/episodes/file_index"] = 0
|
||||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
def convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_video_metadata=None, episodes_audio_metadata=None
|
||||
):
|
||||
logging.info(f"Converting episodes metadata from {root} to {new_root}")
|
||||
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
@@ -410,13 +542,19 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
if episodes_audio_metadata is not None:
|
||||
num_eps_set.add(len(episodes_audio_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
episodes_legacy_metadata,
|
||||
episodes_metadata,
|
||||
episodes_stats,
|
||||
episodes_video_metadata,
|
||||
episodes_audio_metadata,
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
@@ -425,20 +563,22 @@ def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_
|
||||
write_stats(stats, new_root)
|
||||
|
||||
|
||||
def convert_info(root, new_root, data_file_size_in_mb, video_file_size_in_mb, audio_file_size_in_mb=None):
    """Rewrite the legacy dataset info for the new layout and save it under ``new_root``.

    Args:
        root: Root of the dataset being converted (read with ``load_info``).
        new_root: Root of the converted dataset (written with ``write_info``).
        data_file_size_in_mb: Value stored as ``data_files_size_in_mb``.
        video_file_size_in_mb: Value stored as ``video_files_size_in_mb``.
        audio_file_size_in_mb: Value stored as ``audio_files_size_in_mb``. Optional so
            that four-argument calls keep working; ``None`` falls back to
            ``DEFAULT_AUDIO_FILE_SIZE_IN_MB``.
    """
    info = load_info(root)
    info["codebase_version"] = V30
    del info["total_chunks"]
    del info["total_videos"]
    info["data_files_size_in_mb"] = data_file_size_in_mb
    info["video_files_size_in_mb"] = video_file_size_in_mb
    info["audio_files_size_in_mb"] = (
        DEFAULT_AUDIO_FILE_SIZE_IN_MB if audio_file_size_in_mb is None else audio_file_size_in_mb
    )
    info["data_path"] = DEFAULT_DATA_PATH
    info["video_path"] = DEFAULT_VIDEO_PATH if info["video_path"] is not None else None
    # `.get` rather than indexing: a legacy info file may have no "audio_path"
    # entry at all, which must be treated like "no audio" instead of raising KeyError.
    info["audio_path"] = DEFAULT_AUDIO_PATH if info.get("audio_path") is not None else None
    info["fps"] = int(info["fps"])
    logging.info(f"Converting info from {root} to {new_root}")
    for ft in info["features"].values():
        if ft["dtype"] in ("video", "audio"):
            # already has fps in video_info or audio_info
            continue
        ft["fps"] = info["fps"]
    write_info(info, new_root)
|
||||
@@ -449,6 +589,7 @@ def convert_dataset(
|
||||
branch: str | None = None,
|
||||
data_file_size_in_mb: int | None = None,
|
||||
video_file_size_in_mb: int | None = None,
|
||||
audio_file_size_in_mb: int | None = None,
|
||||
root: str | Path | None = None,
|
||||
push_to_hub: bool = True,
|
||||
force_conversion: bool = False,
|
||||
@@ -457,6 +598,8 @@ def convert_dataset(
|
||||
data_file_size_in_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
|
||||
if video_file_size_in_mb is None:
|
||||
video_file_size_in_mb = DEFAULT_VIDEO_FILE_SIZE_IN_MB
|
||||
if audio_file_size_in_mb is None:
|
||||
audio_file_size_in_mb = DEFAULT_AUDIO_FILE_SIZE_IN_MB
|
||||
|
||||
# First check if the dataset already has a v3.0 version
|
||||
if root is None and not force_conversion:
|
||||
@@ -498,7 +641,10 @@ def convert_dataset(
|
||||
convert_tasks(root, new_root)
|
||||
episodes_metadata = convert_data(root, new_root, data_file_size_in_mb)
|
||||
episodes_videos_metadata = convert_videos(root, new_root, video_file_size_in_mb)
|
||||
convert_episodes_metadata(root, new_root, episodes_metadata, episodes_videos_metadata)
|
||||
episodes_audios_metadata = convert_audios(root, new_root, audio_file_size_in_mb)
|
||||
convert_episodes_metadata(
|
||||
root, new_root, episodes_metadata, episodes_videos_metadata, episodes_audios_metadata
|
||||
)
|
||||
|
||||
shutil.move(str(root), str(old_root))
|
||||
shutil.move(str(new_root), str(root))
|
||||
@@ -511,7 +657,7 @@ def convert_dataset(
|
||||
print(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
|
||||
pass
|
||||
hub_api.delete_files(
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*"],
|
||||
delete_patterns=["data/chunk*/episode_*", "meta/*.jsonl", "videos/chunk*", "audio/chunk*"],
|
||||
repo_id=repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
@@ -549,6 +695,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for data and 500 for videos.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--audio-file-size-in-mb",
|
||||
type=int,
|
||||
default=None,
|
||||
help="File size in MB. Defaults to 100 for audio.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
|
||||
@@ -397,42 +397,42 @@ def encode_video_frames(
|
||||
raise OSError(f"Video encoding did not work. File not found: {video_path}.")
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
def concatenate_media_files(
|
||||
input_media_paths: list[Path | str], output_media_path: Path, overwrite: bool = True
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
Concatenate multiple media files (video & audio) into a single media file using pyav.
|
||||
|
||||
This function takes a list of video input file paths and concatenates them into a single
|
||||
output video file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
This function takes a list of input media file paths and concatenates them into a single
|
||||
output media file. It uses ffmpeg's concat demuxer with stream copy mode for fast
|
||||
concatenation without re-encoding.
|
||||
|
||||
Args:
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
input_media_paths: Ordered list of input media file paths to concatenate.
|
||||
output_media_path: Path to the output media file.
|
||||
overwrite: Whether to overwrite the output media file if it already exists. Default is True.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input videos to have the same
|
||||
- Creates a temporary .ffconcat file and container audio/video file that are cleaned up after use.
|
||||
- Uses ffmpeg's concat demuxer which requires all input media files to have the same
|
||||
codec, resolution, and frame rate for proper concatenation.
|
||||
"""
|
||||
|
||||
output_video_path = Path(output_video_path)
|
||||
output_media_path = Path(output_media_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
if output_media_path.exists() and not overwrite:
|
||||
logging.warning(f"Media file already exists: {output_media_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_media_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
if len(input_media_paths) == 0:
|
||||
raise FileNotFoundError("No input media paths provided.")
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
# Create a temporary .ffconcat file to list the input media paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
for input_path in input_video_paths:
|
||||
for input_path in input_media_paths:
|
||||
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
|
||||
tmp_concatenate_file.flush()
|
||||
tmp_concatenate_path = tmp_concatenate_file.name
|
||||
@@ -442,11 +442,12 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
# Using an intermediate container to store the concatenated media file is necessary to avoid inplace concatenation read-write race conditions.
|
||||
with tempfile.NamedTemporaryFile(suffix=output_media_path.suffix, delete=False) as tmp_named_file:
|
||||
tmp_output_media_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
tmp_output_media_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
|
||||
# Replicate input streams in output container
|
||||
@@ -461,6 +462,7 @@ def concatenate_video_files(
|
||||
stream_map[input_stream.index].time_base = input_stream.time_base
|
||||
|
||||
# Demux + remux packets (no re-encode)
|
||||
last_dts = None
|
||||
for packet in input_container.demux():
|
||||
# Skip packets from un-mapped streams
|
||||
if packet.stream.index not in stream_map:
|
||||
@@ -469,6 +471,16 @@ def concatenate_video_files(
|
||||
# Skip demux flushing packets
|
||||
if packet.dts is None:
|
||||
continue
|
||||
else:
|
||||
# Enforce strictly increasing decoding timestamps (DTS)
|
||||
if last_dts is not None and packet.dts <= last_dts:
|
||||
shift = last_dts - packet.dts + 1
|
||||
packet.dts += shift
|
||||
packet.pts += shift # Presenting timestamps (PTS) are the same as DTS here
|
||||
logging.warning(
|
||||
f"Non-monotonic DTS; previous: {last_dts}, current: {packet.dts - shift}; changing to {packet.dts}. This may result in incorrect timestamps in the output file."
|
||||
)
|
||||
last_dts = packet.dts
|
||||
|
||||
output_stream = stream_map[packet.stream.index]
|
||||
packet.stream = output_stream
|
||||
@@ -476,7 +488,7 @@ def concatenate_video_files(
|
||||
|
||||
input_container.close()
|
||||
output_container.close()
|
||||
shutil.move(tmp_output_video_path, output_video_path)
|
||||
shutil.move(tmp_output_media_path, output_media_path)
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
@@ -512,38 +524,6 @@ with warnings.catch_warnings():
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
    """Describe the first audio stream of a media file.

    Args:
        video_path: Path of the file to open with PyAV.

    Returns:
        ``{"has_audio": False}`` when the file has no audio stream; otherwise a
        dictionary with the ``audio.channels``, ``audio.codec``, ``audio.bit_rate``,
        ``audio.sample_rate``, ``audio.bit_depth`` and ``audio.channel_layout``
        entries of the first audio stream, plus ``"has_audio": True``.
    """
    # Set logging level
    logging.getLogger("libav").setLevel(av.logging.ERROR)

    try:
        with av.open(str(video_path), "r") as audio_file:
            try:
                audio_stream = audio_file.streams.audio[0]
            except IndexError:
                return {"has_audio": False}

            return {
                "audio.channels": audio_stream.channels,
                "audio.codec": audio_stream.codec.canonical_name,
                # In an ideal lossless case: bit depth x sample rate x channels = bit rate.
                # In an actual compressed case, the bit rate is set according to the compression
                # level: the lower the bit rate, the more compression is applied.
                "audio.bit_rate": audio_stream.bit_rate,
                "audio.sample_rate": audio_stream.sample_rate,  # Number of samples per second
                # In an ideal lossless case: fixed number of bits per sample.
                # In an actual compressed case: variable number of bits per sample.
                "audio.bit_depth": audio_stream.format.bits,
                "audio.channel_layout": audio_stream.layout.name,
                "has_audio": True,
            }
    finally:
        # Reset logging level on every exit path, including when opening or probing fails.
        av.logging.restore_default_callback()
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
@@ -573,9 +553,6 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Reset logging level
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
@@ -590,22 +567,22 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
def get_media_duration_in_s(media_path: Path | str, media_type: str = "video") -> float:
|
||||
"""
|
||||
Get the duration of a video file in seconds using PyAV.
|
||||
Get the duration of a media file (video & audio) in seconds using PyAV.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
media_path: Path to the media file.
|
||||
|
||||
Returns:
|
||||
Duration of the video in seconds.
|
||||
Duration of the media file in seconds.
|
||||
"""
|
||||
with av.open(str(video_path)) as container:
|
||||
# Get the first video stream
|
||||
video_stream = container.streams.video[0]
|
||||
with av.open(str(media_path)) as container:
|
||||
# Get the first stream
|
||||
stream = container.streams.video[0] if media_type == "video" else container.streams.audio[0]
|
||||
# Calculate duration: stream.duration * stream.time_base gives duration in seconds
|
||||
if video_stream.duration is not None:
|
||||
duration = float(video_stream.duration * video_stream.time_base)
|
||||
if stream.duration is not None:
|
||||
duration = float(stream.duration * stream.time_base)
|
||||
else:
|
||||
# Fallback to container duration if stream duration is not available
|
||||
duration = float(container.duration / av.time_base)
|
||||
@@ -614,12 +591,12 @@ def get_video_duration_in_s(video_path: Path | str) -> float:
|
||||
|
||||
class VideoEncodingManager:
|
||||
"""
|
||||
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
|
||||
Context manager that ensures proper video and audio encoding and data cleanup even if exceptions occur.
|
||||
|
||||
This manager handles:
|
||||
- Batch encoding for any remaining episodes when recording interrupted
|
||||
- Cleaning up temporary image files from interrupted episodes
|
||||
- Removing empty image directories
|
||||
- Cleaning up temporary image and audio files from interrupted episodes
|
||||
- Removing empty image and audio directories
|
||||
|
||||
Args:
|
||||
dataset: The LeRobotDataset instance
|
||||
@@ -646,6 +623,7 @@ class VideoEncodingManager:
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
self.dataset._batch_save_episode_video(start_ep, end_ep)
|
||||
self.dataset._batch_save_episode_audio(start_ep, end_ep)
|
||||
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
@@ -662,6 +640,15 @@ class VideoEncodingManager:
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
for key in self.dataset.meta.audio_keys:
|
||||
audio_file = self.dataset._get_raw_audio_file_path(
|
||||
episode_index=interrupted_episode_index, audio_key=key
|
||||
)
|
||||
if audio_file.exists():
|
||||
logging.debug(
|
||||
f"Cleaning up interrupted episode audio for episode {interrupted_episode_index}, microphone {key}"
|
||||
)
|
||||
audio_file.unlink()
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
@@ -675,4 +662,16 @@ class VideoEncodingManager:
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
# Clean up any remaining audio directory if it's empty
|
||||
audio_dir = self.dataset.root / "raw_audio"
|
||||
# Check for any remaining WAV files
|
||||
wav_files = list(audio_dir.rglob("*.wav"))
|
||||
if len(wav_files) == 0:
|
||||
# Only remove the raw_audio directory if no WAV files remain
|
||||
if audio_dir.exists():
|
||||
shutil.rmtree(audio_dir)
|
||||
logging.debug("Cleaned up empty audio directory")
|
||||
else:
|
||||
logging.debug(f"Audio directory is not empty, containing {len(wav_files)} WAV files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
from .microphone import Microphone
|
||||
from .utils import make_microphones_from_configs
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
from .sensor_anyskin import AnyskinSensor
|
||||
@@ -0,0 +1,45 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("anyskin")
@dataclass
class AnyskinSensorConfig(MicrophoneConfig):
    """Configuration class for Anyskin tactile sensors (technically not a microphone, but behaves like one acquisition-wise).

    This class provides configuration options for Anyskin tactile sensors, including serial port, sample rate and channels.

    Example configurations (keyword arguments are used because the positional order
    also depends on the fields declared by `MicrophoneConfig`):
    ```python
    AnyskinSensorConfig(sensor_port="/dev/ttyACM0", sample_rate=16000, channels=[1])
    AnyskinSensorConfig(sensor_port="/dev/ttyACM1", baud_rate=115200, sample_rate=44100, channels=[1])
    ```

    Attributes:
        sensor_port: Serial port of the tactile sensor.
        baud_rate: Baud rate of the tactile sensor (defaults to 115200).
        sensor_id: Integer identifier, defaults to 0. NOTE(review): presumably selects
            which sensor to address; confirm against the Anyskin driver.
        burst_mode: Defaults to True. NOTE(review): presumably forwarded to the Anyskin
            driver; confirm.
        temp_filtered: Defaults to False. NOTE(review): presumably forwarded to the
            Anyskin driver; confirm.
        sample_rate: Sample rate in Hz for the tactile sensor (not declared here;
            presumably inherited from `MicrophoneConfig`).
        channels: List of channel numbers to use for the tactile sensor (not declared
            here; presumably inherited from `MicrophoneConfig`).
    """

    sensor_port: str
    baud_rate: int = 115_200
    sensor_id: int = 0
    burst_mode: bool = True
    temp_filtered: bool = False
|
||||
@@ -0,0 +1,473 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the AnyskinSensor class for capturing tactile data from Anyskin tactile sensors.
|
||||
"""
|
||||
|
||||
from doctest import master
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.hub import T
|
||||
import numpy as np
|
||||
from serial import Serial, serialutil
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_anyskin import AnyskinSensorConfig
|
||||
|
||||
from anyskin import AnySkinBase, AnySkinDummy
|
||||
|
||||
logger = logging.getLogger(__name__)

# Number of magnetometers requested from the sensor (`num_mags`); it is also the
# number of columns each raw sample is reshaped to before channel selection.
MAX_MAGNETS_CHANNELS = 5
|
||||
|
||||
class AnyskinSensor(Microphone):
    """
    The AnyskinSensor class handles all Anyskin tactile sensors.

    An AnyskinSensor instance requires the serial port of the tactile sensor, which may be obtained using `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of recorded channels.

    The class reuses the `Microphone` interface, which is why the methods below keep the
    microphone/recording vocabulary even though the samples are tactile readings.

    Example of usage:
    ```python
    # NOTE(review): this import path looks stale; this module itself imports the config
    # from `.configuration_anyskin` - confirm the public path before copying the example.
    from lerobot.common.robot_devices.microphones.configs import AnyskinSensorConfig

    config = AnyskinSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
    sensor = AnyskinSensor(config)

    sensor.connect()
    sensor.start_recording("some/output/file.wav")
    ...
    tactile_readings = sensor.read()  # Gets all tactile data recorded since the last read or since the beginning of the recording. The longer the period, the longer the reading time!
    ...
    sensor.stop_recording()
    sensor.disconnect()
    ```
    """
|
||||
|
||||
def __init__(self, config: AnyskinSensorConfig):
    """
    Initializes the AnyskinSensor instance.

    Only in-memory state is prepared here; the recording process itself is
    created by `connect()`.

    Args:
        config: The configuration settings for the sensor.
    """
    super().__init__(config)

    # Serial link settings, copied from the configuration
    self.sensor_port = config.sensor_port
    self.baud_rate = config.baud_rate

    # Recording process handle and the events used to drive it
    self.record_process = None
    (
        self.record_stop_event,
        self.record_start_event,
        self.record_close_event,
        self.record_is_started_event,
        self.audio_callback_start_event,
    ) = (process_Event() for _ in range(5))

    # Process-safe queue carrying samples from the recording process to the writer
    self.write_queue = process_Queue()

    # Shared array (and its local view) filled by the recording process; allocated on connect
    self.read_shared_array = self.local_read_shared_array = None

    # Writer thread/process and its events; created when a recording with an output file starts
    self.write_thread = self.write_stop_event = self.write_is_started_event = None

    self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self.sensor_port})"
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the sensor is currently connected.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is connected and ready to start recording,
|
||||
False otherwise.
|
||||
"""
|
||||
return self.record_process is not None and self.record_process.is_alive()
|
||||
|
||||
@property
|
||||
def is_recording(self) -> bool:
|
||||
"""Check if the sensor is currently recording.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is recording, False otherwise.
|
||||
"""
|
||||
return self.record_is_started_event.is_set()
|
||||
|
||||
@property
|
||||
def is_writing(self) -> bool:
|
||||
"""Check if the sensor is currently writing to a file.
|
||||
|
||||
Returns:
|
||||
bool: True if the sensor is writing to a file, False otherwise.
|
||||
"""
|
||||
return self.write_thread is not None and self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
|
||||
def find_microphones() -> list[dict[str, Any]]:
|
||||
"""Detects available sensors connected to the system.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: A list of dictionaries,
|
||||
where each dictionary contains information about a detected sensor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def connect(self) -> None:
    """
    Establish connection to the sensor.

    Allocates the shared read buffer and the write queue, resets the
    synchronisation events and spawns the daemon recording process, then waits
    up to 5 seconds for that process to report a successful initialisation.

    Raises:
        DeviceAlreadyConnectedError: If the recording process is already alive.
        ValueError: If the sample rate or the channel list is not configured.
        RuntimeError: If the recording process did not initialise in time.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")

    # Both values are needed to size the shared buffer below. Fail with an explicit
    # message instead of the TypeError that `None * 10` or `len(None)` would raise.
    if self.sample_rate is None or self.channels is None:
        raise ValueError(
            f"Sensor connected to {self.sensor_port} requires both sample_rate and channels to be set."
        )

    # Create or reset queue and shared array (the buffer holds sample_rate * 10 rows)
    self.read_shared_array = SharedArray(
        shape=(self.sample_rate * 10, len(self.channels)),
        dtype=np.dtype("int16"),
    )
    self.local_read_shared_array = self.read_shared_array.get_local_array()
    self.write_queue = process_Queue()

    # Reset events
    self.record_start_event.clear()
    self.record_stop_event.clear()
    self.record_close_event.clear()
    self.record_is_started_event.clear()
    self.audio_callback_start_event.clear()

    # Create and start the recording process.
    # Remark: this is done in a separate process so that recording is not impacted by the main thread CPU usage, especially the busy_wait function.
    process_init_event = process_Event()
    self.record_process = Process(
        target=self._record_process,
        args=(
            self.sensor_port,
            self.baud_rate,
            self.channels,
            process_init_event,
            self.record_start_event,
            self.record_stop_event,
            self.record_close_event,
            self.record_is_started_event,
            self.audio_callback_start_event,
            self.write_queue,
            self.read_shared_array,
        ),
    )
    self.record_process.daemon = True
    self.record_process.start()

    # Wait for the recording process to be started, and to potentially fail.
    is_init = process_init_event.wait(timeout=5.0)
    if not self.is_connected or not is_init:
        # Do not leave a half-started process behind: it would keep `is_connected`
        # True (and the serial port busy) although the connection attempt failed.
        if self.record_process.is_alive():
            self.record_process.terminate()
        self.record_process.join(timeout=1.0)
        self.read_shared_array.delete()
        self.read_shared_array = None
        self.local_read_shared_array = None
        raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")

    logger.info(f"{self} connected.")
|
||||
|
||||
@staticmethod
def _record_process(
    sensor_port,
    baud_rate,
    channels,
    process_init_event,
    record_start_event,
    record_stop_event,
    record_close_event,
    record_is_started_event,
    audio_callback_start_event,
    write_queue,
    read_shared_array,
) -> None:
    """
    Entry point of the recording process.

    Opens the sensor, signals `process_init_event`, then alternates between
    waiting for `record_start_event` and polling the sensor until
    `record_stop_event` is set. The process exits once `record_close_event` is set.

    Args:
        sensor_port: Serial port of the sensor.
        baud_rate: Baud rate of the serial connection.
        channels: Channel numbers to keep; they are shifted by -1 to index the sample columns.
        process_init_event: Set once the sensor has been opened successfully.
        record_start_event: Set by the parent to request a recording.
        record_stop_event: Set by the parent to end the current recording.
        record_close_event: Set by the parent to terminate this process.
        record_is_started_event: Set while this process is in its recording loop.
        audio_callback_start_event: Samples are only fetched and published while this is set.
        write_queue: Queue receiving every published sample block.
        read_shared_array: Shared array receiving every published sample block.

    Raises:
        RuntimeError: If the sensor cannot be opened.
    """
    channels_index = np.array(channels) - 1
    local_read_shared_array = read_shared_array.get_local_array()

    def tactile_callback(tactile_sensor: AnySkinBase):
        """
        Fetch one sample and publish the selected channels, once publishing is enabled.
        """
        if audio_callback_start_event.is_set():
            _, indata = tactile_sensor.get_sample()  # The sample timestamp is not used
            indata = indata.reshape(-1, MAX_MAGNETS_CHANNELS)
            write_queue.put_nowait(indata[:, channels_index])
            read_shared_array.write(local_read_shared_array, indata[:, channels_index])

    try:
        tactile_sensor = AnySkinBase(
            num_mags=MAX_MAGNETS_CHANNELS,
            port=sensor_port,
            baudrate=baud_rate,
            burst_mode=True,
            device_id=0,  # TODO(CarolinePascal): create an abstract increasing id for each sensor
            temp_filtered=False,
        )  # TODO(CarolinePascal): add timeout on serial connection ?
    except (serialutil.SerialException, AttributeError) as e:
        # Chain the original error so that its traceback is preserved
        raise RuntimeError(f"Error connecting sensor connected to {sensor_port}: {e}") from e

    process_init_event.set()

    try:
        while True:
            start_flag = record_start_event.wait(timeout=0.1)
            if record_close_event.is_set():
                break
            elif not start_flag:
                continue
            record_is_started_event.set()
            while not record_stop_event.is_set():
                tactile_callback(tactile_sensor)  # Initial flush is already done in the constructor.
            record_is_started_event.clear()
    finally:
        # Also reached when the loop raises (e.g. the sensor is unplugged): never leave
        # the "started" flag set or the serial connection open behind a dead process.
        record_is_started_event.clear()
        tactile_sensor.close()  # Closes the inherited serial connection.
|
||||
|
||||
def disconnect(self) -> None:
    """
    Disconnect the sensor and release any resources.

    Stops an ongoing recording first, then asks the recording process to exit,
    releases the shared array and the write queue, and waits for the process to end.

    Raises:
        DeviceNotConnectedError: If the recording process is not alive.
        RuntimeError: If the recording process is still alive after being joined.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")

    if self.is_recording:
        self.stop_recording()

    self.record_close_event.set()  # Makes the recording process leave its waiting loop
    self.read_shared_array.delete()
    self.write_queue.close()
    self.record_process.join()  # No timeout: blocks until the recording process has exited

    if self.is_connected:
        raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")

    logger.info(f"{self} disconnected.")
|
||||
|
||||
def start_recording(
    self,
    output_file: str | Path | None = None,
    multiprocessing: bool | None = False,
    overwrite: bool | None = True,
    barrier: Barrier | None = None,
) -> None:
    """
    Start recording tactile data from the sensor.

    Args:
        output_file: Optional path to save the recorded tactile data.
        multiprocessing: If True, the writer runs in a separate process; otherwise in a thread.
        overwrite: If True, overwrites existing files at output_file path.
        barrier: If not None, ensures that multiple sensors start recording at the same time.

    Raises:
        DeviceNotConnectedError: If the sensor is not connected.
        DeviceAlreadyRecordingError: If a recording is already running.
        FileExistsError: If output_file exists and overwrite is False.
        RuntimeError: If the recording or the writing did not start.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
    if self.is_recording:
        raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")

    # Start from a clean slate: empty shared buffer, empty queue, stop flag lowered
    self.read_shared_array.reset()
    self._clear_queue(self.write_queue)
    self.record_stop_event.clear()

    # A writer is only spawned when an output file is requested
    if output_file is not None:
        output_file = Path(output_file)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            if not overwrite:
                raise FileExistsError(
                    f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
                )
            output_file.unlink()

        # Same writer loop either way; only the event and worker flavours differ
        if multiprocessing:
            make_event, make_worker = process_Event, Process
        else:
            make_event, make_worker = thread_Event, Thread

        self.write_stop_event = make_event()
        self.write_is_started_event = make_event()
        self.write_thread = make_worker(
            target=AnyskinSensor._write_loop,
            args=(
                self.write_queue,
                self.write_stop_event,
                self.write_is_started_event,
                self.sample_rate,
                self.channels,
                output_file,
            ),
        )
        self.write_thread.daemon = True
        self.write_thread.start()
        self.write_is_started_event.wait()  # Wait for the writing thread/process to be started.

    self.record_start_event.set()  # Ask the recording process to start
    self.record_is_started_event.wait()  # Wait until it has actually started

    if barrier is not None:
        barrier.wait()  # Wait for the other devices so that all of them start together

    self.audio_callback_start_event.set()

    if not self.is_recording:
        raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
    if output_file is not None and not self.is_writing:
        raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def _read(self) -> np.ndarray:
|
||||
"""
|
||||
Thread/Process-safe callback to read available audio data
|
||||
"""
|
||||
return self.read_shared_array.read(self.local_read_shared_array, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
"""Capture and return a single audio chunk from the sensor.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Captured audio chunk as a numpy array.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
|
||||
if not self.is_recording:
|
||||
raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
tactile_readings = self._read()
|
||||
|
||||
# log the number of seconds it took to read the audio chunk
|
||||
self.logs["delta_timestamp_s"] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the audio chunk was received
|
||||
self.logs["timestamp_utc"] = time.perf_counter()
|
||||
|
||||
return tactile_readings
|
||||
|
||||
def _read_loop(self) -> None:
    """Internal loop run by the background thread for asynchronous reading.

    NOTE(review): this is currently an empty placeholder (it only has this docstring),
    and no code visible in this class starts a thread running it.
    """
|
||||
|
||||
def stop_recording(self) -> None:
    """Stop recording tactile data from the sensor.

    Raises:
        DeviceNotConnectedError: If the sensor is not connected.
        DeviceNotRecordingError: If the sensor is not recording.
        RuntimeError: If recording or writing is still reported as active afterwards.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
    if not self.is_recording:
        raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")

    self.audio_callback_start_event.clear()  # Stop publishing new samples
    self.record_start_event.clear()  # Ensures the recording is not started again !
    self.record_stop_event.set()

    # Samples still pending are dropped here: the shared array is reset and the write
    # queue is emptied, then joined so that every item taken from it is marked as done.
    self.read_shared_array.reset()
    self._clear_queue(self.write_queue, join_queue=True)

    if self.is_writing:
        self.write_stop_event.set()
        self.write_thread.join()

    # Give the recording process about one second (polled every 10 ms) to clear its "started" flag
    timeout = 1.0
    while self.is_recording and timeout > 0:
        time.sleep(0.01)
        timeout -= 0.01

    if self.is_recording:
        raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
    if self.is_writing:
        raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
|
||||
@staticmethod
|
||||
def _clear_queue(queue, join_queue: bool = False):
|
||||
"""
|
||||
Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except Empty:
|
||||
if join_queue:
|
||||
queue.join()
|
||||
return
|
||||
|
||||
@staticmethod
def _write_loop(
    queue,
    write_stop_event: Event,
    write_is_started_event: Event,
    sample_rate: int,
    channels: list[int],
    output_file: Path,
) -> None:
    """
    Thread/Process-safe loop to write queued data into a file.

    Args:
        queue: Queue the sample blocks are taken from.
        write_stop_event: Set by the caller to end the loop.
        write_is_started_event: Set once the file is open, cleared when the loop ends.
        sample_rate: Sample rate written in the file header.
        channels: Recorded channels; only their count is used.
        output_file: Path of the file to write.
    """
    # Can only be run on a single process/thread for file writing safety
    try:
        with SoundFile(
            output_file,
            mode="w",
            samplerate=sample_rate,
            channels=len(channels),
            format="WAV",
            subtype="FLOAT",  # Subtype for float32 values
        ) as file:
            write_is_started_event.set()
            while not write_stop_event.is_set():
                try:
                    # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
                    chunk = queue.get(timeout=0.005)
                except Empty:
                    continue
                try:
                    file.write(chunk)
                finally:
                    # Always acknowledge the item, otherwise a failed write would leave
                    # an unfinished task behind and block any later `queue.join()`.
                    queue.task_done()
    finally:
        # Cleared even when opening or writing the file fails, so that the writer is
        # never reported as running after this loop has ended.
        write_is_started_event.clear()
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class MicrophoneConfig(draccus.ChoiceRegistry, abc.ABC):
    """Base configuration shared by microphone configurations.

    The dataclass is declared with `kw_only=True`, so both fields below must be
    passed by keyword, including from subclasses.
    """

    # Sample rate in Hz; None when left unset.
    sample_rate: int | None = None
    # Channel numbers to record; None when left unset.
    channels: list[int] | None = None

    @property
    def type(self) -> str:
        """Return the choice name that the registry associates with this class."""
        return self.get_choice_name(self.__class__)
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from threading import Barrier
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
|
||||
|
||||
class Microphone(abc.ABC):
    """Abstract base class shared by all microphone backends.

    It fixes the interface every backend has to provide and stores the two
    settings common to all of them (sample rate and channels). Concrete
    subclasses must implement every abstract member:
        - connection state and connect/disconnect
        - recording state and start/stop recording
        - writing state
        - device discovery
        - reading of the recorded data

    Attributes:
        sample_rate (int | None): Configured sample rate in Hz
        channels (list[int] | None): List of channel numbers to record

    Example:
        class MyMicrophone(Microphone):
            def __init__(self, config): ...
            @property
            def is_connected(self) -> bool: ...
            def connect(self): ...
            # Plus other required methods
    """

    def __init__(self, config: MicrophoneConfig):
        """Store the recording settings carried by the configuration.

        Args:
            config: Microphone configuration containing sample rate and channels.
        """
        self.sample_rate: int | None = config.sample_rate
        self.channels: list[int] | None = config.channels

    @property
    @abc.abstractmethod
    def is_connected(self) -> bool:
        """Tell whether the microphone is currently connected.

        Returns:
            bool: True if the microphone is connected and ready to start recording,
                False otherwise.
        """

    @property
    @abc.abstractmethod
    def is_recording(self) -> bool:
        """Tell whether the microphone is currently recording.

        Returns:
            bool: True if the microphone is recording, False otherwise.
        """

    @property
    @abc.abstractmethod
    def is_writing(self) -> bool:
        """Tell whether the microphone is currently writing to a file.

        Returns:
            bool: True if the microphone is writing to a file, False otherwise.
        """

    @staticmethod
    @abc.abstractmethod
    def find_microphones() -> list[dict[str, Any]]:
        """Detect the microphones connected to the system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
                where each dictionary contains information about a detected microphone.
        """

    @abc.abstractmethod
    def connect(self) -> None:
        """Establish the connection to the microphone."""

    @abc.abstractmethod
    def start_recording(
        self,
        output_file: str | Path | None = None,
        multiprocessing: bool | None = False,
        overwrite: bool | None = True,
        barrier: Barrier | None = None,
    ) -> None:
        """Start recording audio from the microphone.

        Args:
            output_file: Optional path to save the recorded audio.
            multiprocessing: If True, enables multiprocessing for recording. Defaults to multithreading otherwise.
            overwrite: If True, overwrites existing files at output_file path.
            barrier: If not None, ensures that multiple microphones start recording at the same time.
        """

    @abc.abstractmethod
    def read(self) -> np.ndarray:
        """Capture and return a single audio chunk from the microphone.

        Returns:
            np.ndarray: Captured audio chunk as a numpy array.
        """

    @abc.abstractmethod
    def stop_recording(self) -> None:
        """Stop recording audio from the microphone."""

    @abc.abstractmethod
    def disconnect(self) -> None:
        """Disconnect the microphone and release any resources."""
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||
from .microphone_portaudio import PortAudioMicrophone
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("portaudio")
@dataclass
class PortAudioMicrophoneConfig(MicrophoneConfig):
    """Configuration class for PortAudio-based microphone devices.

    This class provides configuration options for microphones accessed through PortAudio with the
    sounddevice Python package, including device index, sample rate and channels.

    Example configurations:
    ```python
    # Basic configurations
    # `sample_rate` and `channels` are inherited from MicrophoneConfig, which is declared
    # keyword-only, so they cannot be passed positionally.
    PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])  # Device index 0, 16000Hz, mono
    PortAudioMicrophoneConfig(microphone_index=1, sample_rate=44100, channels=[1, 2])  # Device index 1, 44100Hz, stereo
    ```

    Attributes:
        microphone_index: Device index for the microphone.
        sample_rate: Sample rate in Hz for the microphone.
        channels: List of channel numbers to use for the microphone.
    """

    microphone_index: int
|
||||
@@ -0,0 +1,394 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from threading import Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sounddevice import PortAudioError
|
||||
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
|
||||
# --- Interface definitions for InputStream ---
|
||||
class IInputStream(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: float | None = None,
|
||||
blocksize: int | None = None,
|
||||
device: int | str | None = None,
|
||||
channels: int | None = None,
|
||||
dtype: str | np.dtype | None = None,
|
||||
latency: float | str | None = None,
|
||||
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def stop(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ISounddeviceSDK(abc.ABC):
    """Interface defining the contract for the Sounddevice SDK."""

    # Stream class exposed by the SDK; implementations bind it to an IInputStream subclass.
    InputStream: type[IInputStream]

    @abc.abstractmethod
    def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
        """Return information about the devices matching `device` and `kind`."""
|
||||
|
||||
|
||||
# --- Real SDK Adapter ---
|
||||
|
||||
|
||||
class SounddeviceSDKAdapter(ISounddeviceSDK):
    """Adapts the real sounddevice library to the ISounddeviceSDK interface."""

    # Module object of the real library, cached on the class by __init__.
    _sounddevice = None

    def __init__(self):
        """Import sounddevice and cache the module on the class.

        Raises:
            ImportError: If the sounddevice library cannot be imported.
        """
        try:
            import sounddevice

            SounddeviceSDKAdapter._sounddevice = sounddevice
        except ImportError as e:
            raise ImportError("sounddevice library not found") from e

    # --- Inner Class Implementation ---
    class RealInputStream(IInputStream):
        """Thin wrapper forwarding every call to a `sounddevice.InputStream`."""

        def __init__(
            self,
            samplerate: int | None = None,
            blocksize: int | None = None,
            device: int | None = None,
            channels: int | None = None,
            dtype: str | np.dtype | None = None,
            latency: float | str | None = None,
            callback: Callable[[Any, int, Any, Any], None] | None = None,
        ):
            """Create the wrapped stream; all arguments are passed through unchanged."""
            import sounddevice

            self._input_stream = sounddevice.InputStream(
                samplerate=samplerate,
                blocksize=blocksize,
                device=device,
                channels=channels,
                dtype=dtype,
                latency=latency,
                callback=callback,
            )

        def start(self) -> None:
            """Start the wrapped stream."""
            self._input_stream.start()

        def stop(self) -> None:
            """Stop the wrapped stream."""
            self._input_stream.stop()

        def close(self) -> None:
            """Close the wrapped stream."""
            self._input_stream.close()

        def __del__(self):
            """Stop and close the wrapped stream on finalisation."""
            # The attribute is missing when `sounddevice.InputStream` raised in __init__;
            # without this guard, finalisation would fail with an AttributeError.
            stream = getattr(self, "_input_stream", None)
            if stream is not None:
                stream.stop()
                stream.close()

        @property
        def active(self) -> bool:
            """`active` attribute of the wrapped stream."""
            return self._input_stream.active

        @property
        def stopped(self) -> bool:
            """`stopped` attribute of the wrapped stream."""
            return self._input_stream.stopped

        @property
        def closed(self) -> bool:
            """`closed` attribute of the wrapped stream."""
            return self._input_stream.closed

    InputStream = RealInputStream

    def query_devices(self, device: int | str | None = None, kind: str | None = None) -> list[dict[str, Any]]:
        """Forward the query to `sounddevice.query_devices` and return its result."""
        return SounddeviceSDKAdapter._sounddevice.query_devices(device, kind)
|
||||
|
||||
|
||||
# Reference values for the fake SDK defined later in this file.

# Sample formats listed as valid, given either as strings or as numpy types.
VALID_DTYPE = {
    "float32",
    "int32",
    "int16",
    "int8",
    "uint8",
    np.float32,
    np.int32,
    np.int16,
    np.int8,
    np.uint8,
}
# Latency settings listed as valid when given as strings.
VALID_LATENCY = {"low", "high"}

# Emulated devices: a 48kHz stereo microphone (index 0), an output-only device
# (index 1) and a 16kHz mono USB input device (index 2).
VALID_DEVICES = [
    {
        "index": 0,
        "name": "Built-in Microphone",
        "hostapi": 0,
        "max_input_channels": 2,
        "max_output_channels": 0,
        "default_low_input_latency": 0.01,
        "default_low_output_latency": 0.001,
        "default_high_input_latency": 0.1,
        "default_high_output_latency": 0.01,
        "default_samplerate": 48000.0,
    },
    {
        "index": 1,
        "name": "Built-in Output",
        "hostapi": 0,
        "max_input_channels": 0,
        "max_output_channels": 2,
        "default_low_input_latency": 0.04,
        "default_low_output_latency": 0.04,
        "default_high_input_latency": 0.12,
        "default_high_output_latency": 0.12,
        "default_samplerate": 48000.0,
    },
    {
        "index": 2,
        "name": "USB Audio Device",
        "hostapi": 0,
        "max_input_channels": 1,
        "max_output_channels": 0,
        "default_low_input_latency": 0.03,
        "default_low_output_latency": 0.01,
        "default_high_input_latency": 0.04,
        "default_high_output_latency": 0.03,
        "default_samplerate": 16000.0,
    },
]
|
||||
|
||||
# -- Fake SDK Adapter ---
|
||||
|
||||
|
||||
class FakeSounddeviceSDKAdapter(ISounddeviceSDK):
|
||||
"""Implements the ISounddeviceSDK interface with fake behaviour for testing."""
|
||||
|
||||
# --- Inner Class Implementation ---
|
||||
class FakeInputStream(IInputStream):
|
||||
def __init__(
|
||||
self,
|
||||
samplerate: float | None = None,
|
||||
blocksize: int | None = None,
|
||||
device: int | str | None = None,
|
||||
channels: int | None = None,
|
||||
dtype: str | None = None,
|
||||
latency: str | None = None,
|
||||
callback: Callable[[Any, int, Any, Any], None] | None = None,
|
||||
):
|
||||
self.samplerate = samplerate
|
||||
self.blocksize = blocksize
|
||||
self.device = device
|
||||
self.channels = channels
|
||||
self.dtype = dtype
|
||||
self.latency = latency
|
||||
self.callback = callback
|
||||
|
||||
self._validate_settings()
|
||||
|
||||
self._active = False
|
||||
self._closed = False
|
||||
|
||||
if self.callback is not None:
|
||||
self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
|
||||
self._streaming_thread_stop_event = Event()
|
||||
|
||||
@property
|
||||
def active(self) -> bool:
|
||||
"""True when the stream is active, False otherwise."""
|
||||
return self._active
|
||||
|
||||
@property
|
||||
def stopped(self) -> bool:
|
||||
"""True when the stream is stopped, False otherwise."""
|
||||
return not self._active
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
"""True after a call to close(), False otherwise."""
|
||||
return self._closed
|
||||
|
||||
def _get_device_info(self):
|
||||
"""Returns the device info for the device."""
|
||||
for device in VALID_DEVICES:
|
||||
if (isinstance(self.device, int) and device["index"] == self.device) or (
|
||||
isinstance(self.device, str) and device["name"] == self.device
|
||||
):
|
||||
return device
|
||||
raise PortAudioError(f"No input device matching {self.device}")
|
||||
|
||||
def _validate_device(self):
    """Check ``self.device`` against ``VALID_DEVICES``, or choose a default when it is None.

    When no device was requested, the first entry with at least one input
    channel is selected (by index); if there is none, ``self.device`` stays None.

    Raises:
        PortAudioError: If the index or name is unknown, or the device is neither int nor str.
    """
    requested = self.device

    if requested is None:
        # Default to first input device
        first_input = next((d for d in VALID_DEVICES if d["max_input_channels"] > 0), None)
        if first_input is not None:
            self.device = first_input["index"]
        return

    if isinstance(requested, int):
        known_indices = [d["index"] for d in VALID_DEVICES]
        if requested not in known_indices:
            raise PortAudioError(f"Error querying device {self.device}")
    elif isinstance(requested, str):
        known_names = [d["name"] for d in VALID_DEVICES]
        if requested not in known_names:
            raise PortAudioError(f"No input device matching {self.device}")
    else:
        raise PortAudioError(f"Device must be int or str, got {type(self.device)}")
|
||||
|
||||
def _validate_samplerate(self):
    """Default or range-check ``self.samplerate`` using the device's default sample rate.

    A missing sample rate is replaced by the device default. Otherwise the value
    must be at least 1000 and at most the device default.

    Raises:
        PortAudioError: If the sample rate is outside that range.
    """
    device_default = self._get_device_info()["default_samplerate"]

    if self.samplerate is None:
        self.samplerate = device_default
        return

    if self.samplerate > device_default or self.samplerate < 1000:
        raise PortAudioError("Error opening InputStream: Invalid sample rate")
|
||||
|
||||
def _validate_channels(self):
    """Default or range-check ``self.channels`` using the device's input channel count.

    A missing channel count is replaced by the device's ``max_input_channels``.
    Otherwise the value must be at least 1 and at most that maximum.

    Raises:
        PortAudioError: If the channel count is outside that range.
    """
    max_channels = self._get_device_info()["max_input_channels"]

    if self.channels is None:
        self.channels = max_channels
        return

    if self.channels > max_channels or self.channels < 1:
        raise PortAudioError("Error opening InputStream: Invalid number of channels")
|
||||
|
||||
def _validate_dtype(self):
    """Default ``self.dtype`` to ``"float32"`` or check it against ``VALID_DTYPE``.

    Raises:
        PortAudioError: If a dtype was given and is not listed in ``VALID_DTYPE``.
    """
    if self.dtype is None:
        self.dtype = "float32"  # Default dtype
        return

    if self.dtype not in VALID_DTYPE:
        raise PortAudioError("Invalid input sample format")
|
||||
|
||||
def _validate_latency(self):
    """Default/check ``self.latency`` and resolve "low"/"high" to the device's latency value.

    A missing latency becomes ``"low"``; a provided one must be in ``VALID_LATENCY``.
    A string latency of ``"low"`` or ``"high"`` is then replaced by the device's
    ``default_low_input_latency`` or ``default_high_input_latency`` respectively.

    Raises:
        PortAudioError: If a latency was given and is not listed in ``VALID_LATENCY``.
    """
    if self.latency is None:
        self.latency = "low"  # Default latency
    elif self.latency not in VALID_LATENCY:
        raise PortAudioError("Invalid latency")

    if not isinstance(self.latency, str):
        return

    device_info = self._get_device_info()
    info_key_by_mode = {
        "low": "default_low_input_latency",
        "high": "default_high_input_latency",
    }
    info_key = info_key_by_mode.get(self.latency)
    if info_key is not None:
        self.latency = device_info[info_key]
|
||||
|
||||
def _validate_settings(self):
    """Run every settings validator, in order: device, samplerate, channels, dtype, latency.

    The device check runs first because the later validators look up the
    device's info through ``_get_device_info``.
    """
    checks = (
        self._validate_device,
        self._validate_samplerate,
        self._validate_channels,
        self._validate_dtype,
        self._validate_latency,
    )
    for check in checks:
        check()
|
||||
|
||||
def _simulated_audio_data(self) -> np.ndarray:
    """Build one random audio chunk of shape ``(int(samplerate * latency), channels)``.

    float32 streams get uniform samples in [-1, 1); any other dtype is treated as
    an integer type and sampled over its full ``np.iinfo`` range.
    """
    n_frames = int(self.samplerate * self.latency)
    chunk_shape = (n_frames, self.channels)

    if self.dtype in {"float32", np.float32}:
        # Floating-point audio is conventionally normalised to [-1, 1].
        return np.random.uniform(-1.0, 1.0, chunk_shape).astype(self.dtype)

    # randint's upper bound is exclusive, hence max + 1.
    int_limits = np.iinfo(self.dtype)
    return np.random.randint(int_limits.min, int_limits.max + 1, chunk_shape, dtype=self.dtype)
|
||||
|
||||
def _streaming_loop(self):
    """Thread body: feed simulated chunks to the callback until the stop event is set.

    Each iteration sleeps for ``self.latency``, generates a chunk and calls the
    callback with ``(chunk, len(chunk), time.perf_counter(), None)``.
    Does nothing when no callback is set.
    """
    if self.callback is None:
        return

    while not self._streaming_thread_stop_event.is_set():
        precise_sleep(self.latency)
        chunk = self._simulated_audio_data()
        frame_count = len(chunk)
        self.callback(chunk, frame_count, time.perf_counter(), None)
|
||||
|
||||
def start(self) -> None:
    """Start the fake input stream.

    With a callback set, launches the streaming thread. The stream can be
    started again after ``stop()``: the stop event is cleared and, because a
    ``Thread`` object can only be started once, a fresh thread is created when
    the previous one has already been started.
    """
    if not self.active and self.callback is not None:
        # Without this, a restarted loop would see the event set by stop() and exit at once.
        self._streaming_thread_stop_event.clear()
        if self._streaming_thread.ident is not None:
            # ident is only assigned once a thread has been started; such a thread cannot be reused.
            self._streaming_thread = Thread(target=self._streaming_loop, daemon=True)
        self._streaming_thread.start()
    # NOTE(review): the flag is set even without a callback, on the assumption that the
    # original did the same (indentation was lost in the reviewed copy) - confirm.
    self._active = True
|
||||
|
||||
def stop(self) -> None:
    """Stop the fake input stream.

    With a callback set, signals the streaming thread to exit and waits for it.
    Calling this on a stream that was never started is a no-op apart from
    setting the stop event and clearing the active flag.
    """
    if self.callback is not None:
        self._streaming_thread_stop_event.set()
        # join() raises RuntimeError on a thread that was never started, so only
        # wait for a thread that is actually running.
        if self._streaming_thread.is_alive():
            self._streaming_thread.join()
    self._active = False
|
||||
|
||||
def close(self) -> None:
    """Close the fake input stream, stopping it first if it is running with a callback.

    Afterwards the stream is flagged as inactive and closed.
    """
    must_stop_first = self.active and self.callback is not None
    if must_stop_first:
        self.stop()

    self._active = False
    self._closed = True
|
||||
|
||||
def __del__(self):
    """Close the stream when the object is garbage-collected."""
    self.close()
|
||||
|
||||
InputStream = FakeInputStream
|
||||
|
||||
def query_devices(
    self, device: int | str | None = None, kind: str | None = None
) -> list[dict[str, Any]] | dict[str, Any]:
    """Return fake device information from ``VALID_DEVICES``.

    Args:
        device: Device index or name. When given, the single matching entry is returned.
        kind: ``"input"`` or ``"output"``. When given (and ``device`` is None), the first
            entry with at least one channel of that kind is returned.

    Returns:
        A single device dictionary when ``device`` or ``kind`` is given, otherwise
        the whole ``VALID_DEVICES`` list.

    Raises:
        PortAudioError: If no entry matches ``device``, or no entry has channels of ``kind``.
    """
    if device is not None:
        # Return specific device
        for valid_device in VALID_DEVICES:
            if (isinstance(device, int) and valid_device["index"] == device) or (
                isinstance(device, str) and valid_device["name"] == device
            ):
                return valid_device
        raise PortAudioError(f"Error querying device {device}")

    if kind is not None:
        # Any other value of `kind` matches nothing and falls through to the error below.
        channel_key = {"input": "max_input_channels", "output": "max_output_channels"}.get(kind)
        if channel_key is not None:
            for valid_device in VALID_DEVICES:
                if valid_device[channel_key] > 0:
                    return valid_device
        raise PortAudioError(f"No {kind} device found")

    return VALID_DEVICES
|
||||
@@ -0,0 +1,566 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the PortAudioMicrophone class for capturing audio from microphones using the PortAudio library through the sounddevice Python package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.microphones.portaudio.interface_sounddevice_sdk import ISounddeviceSDK, SounddeviceSDKAdapter
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_portaudio import PortAudioMicrophoneConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PortAudioMicrophone(Microphone):
    """
    The PortAudioMicrophone class handles all microphones compatible with sounddevice (and the underlying PortAudio library). Most microphones and sound cards are compatible, across all OS (Linux, Mac, Windows).

    A PortAudioMicrophone instance requires the sounddevice index of the microphone, which may be obtained using `python -m sounddevice`. It also requires the recording sample rate as well as the list of recorded channels.

    Example of usage:
    ```python
    # NOTE(review): import path inferred from this module's own imports - confirm.
    from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig

    config = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=16000, channels=[1])
    microphone = PortAudioMicrophone(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    audio_readings = microphone.read()  # Gets all recorded audio data since the last read or since the beginning of the recording. The longer the period, the longer the reading time!
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """
|
||||
|
||||
def __init__(self, config: PortAudioMicrophoneConfig, sounddevice_sdk: ISounddeviceSDK = None):
    """
    Sets up the microphone state without opening any audio stream.

    Args:
        config: The configuration settings for the microphone.
        sounddevice_sdk: SDK implementation to use. A ``SounddeviceSDKAdapter`` is created when omitted.
    """
    super().__init__(config)

    self.sounddevice_sdk = SounddeviceSDKAdapter() if sounddevice_sdk is None else sounddevice_sdk

    # Index of the microphone, as given in the config
    self.microphone_index = config.microphone_index

    # Recording process (created in connect()) and the events used to drive it
    self.record_process = None
    self.record_stop_event = process_Event()
    self.record_start_event = process_Event()
    self.record_close_event = process_Event()
    self.record_is_started_event = process_Event()
    self.audio_callback_start_event = process_Event()

    # Process-safe queue carrying audio from the recording process to the writing process/thread
    self.write_queue = process_Queue()

    # Shared array (and its local view) holding audio from the recording process; both are set in connect()
    self.read_shared_array = None
    self.local_read_shared_array = None

    # Writing thread/process and its events; all are created in start_recording()
    self.write_thread = None
    self.write_stop_event = None
    self.write_is_started_event = None

    self.logs = {}
|
||||
|
||||
def __str__(self) -> str:
    """Return ``"<ClassName>(<microphone_index>)"``."""
    class_name = self.__class__.__name__
    return f"{class_name}({self.microphone_index})"
|
||||
|
||||
@property
def is_connected(self) -> bool:
    """True when the recording process exists and is alive."""
    process = self.record_process
    return process is not None and process.is_alive()
|
||||
|
||||
@property
def is_recording(self) -> bool:
    """True while ``record_is_started_event`` is set."""
    started_event = self.record_is_started_event
    return started_event.is_set()
|
||||
|
||||
@property
def is_writing(self) -> bool:
    """True when a writing thread/process exists and ``write_is_started_event`` is set."""
    if self.write_thread is None:
        return False
    return self.write_is_started_event.is_set()
|
||||
|
||||
@staticmethod
def find_microphones(
    device: int | str | None = None, sounddevice_sdk: ISounddeviceSDK = None
) -> list[dict[str, Any]] | dict[str, Any]:
    """
    Lists the input-capable devices reported by the sounddevice SDK.

    Args:
        device: Index or name of a specific device. If None, all microphones are returned.
        sounddevice_sdk: SDK implementation to query. A ``SounddeviceSDKAdapter`` is created when omitted.

    Returns:
        When ``device`` is None, a list of dictionaries, one per device with at least one input channel,
        with keys ``index``, ``name``, ``sample_rate`` and ``channels``. When ``device`` is given,
        the dictionary of the first matching microphone.

    Raises:
        RuntimeError: If ``device`` is given and no microphone matches it.
    """
    sdk = SounddeviceSDKAdapter() if sounddevice_sdk is None else sounddevice_sdk

    def _is_requested(entry) -> bool:
        if device is None:
            return True
        if isinstance(device, int):
            return entry["index"] == device
        if isinstance(device, str):
            return entry["name"] == device
        return False

    found_microphones_info = []
    for entry in sdk.query_devices():
        if entry["max_input_channels"] <= 0:
            continue
        microphone_info = {
            "index": entry["index"],
            "name": entry["name"],
            "sample_rate": int(entry["default_samplerate"]),
            # Channel numbers are 1-based
            "channels": np.arange(1, entry["max_input_channels"] + 1),
        }
        if _is_requested(entry):
            found_microphones_info.append(microphone_info)

    if device is not None:
        if not found_microphones_info:
            raise RuntimeError(f"No microphone found for device {device}")
        return found_microphones_info[0]

    if not found_microphones_info:
        logger.warning("No microphone found !")

    return found_microphones_info
|
||||
|
||||
def _configure_capture_settings(self) -> None:
    """
    Checks the microphone index, sample rate and channels from the config while the microphone is not connected.

    The sample rate and channels are filled in from the device when they were not specified. This runs before any PortAudio stream is started.

    Raises:
        RuntimeError: If one of the specified settings is not compatible with the microphone.
        DeviceAlreadyConnectedError: If the microphone is connected when attempting to configure settings.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(
            f"Cannot configure settings for {self} as it is already connected."
        )

    # The index is checked first: the other two validators look the device up by that index.
    for validate in (
        self._validate_microphone_index,
        self._validate_sample_rate,
        self._validate_channels,
    ):
        validate()
|
||||
|
||||
def _validate_microphone_index(self) -> None:
    """Checks that the configured microphone index matches a device with at least one input channel.

    Raises:
        RuntimeError: If no such microphone is found; the message lists the available microphones.
    """
    try:
        PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
    except RuntimeError as e:
        available = PortAudioMicrophone.find_microphones(sounddevice_sdk=self.sounddevice_sdk)
        raise RuntimeError(f"{e}. Available microphones: {available}") from e
|
||||
|
||||
def _validate_sample_rate(self) -> None:
    """Checks the configured sample rate against the microphone's default sample rate, or adopts that default.

    A provided sample rate is converted to ``int`` and must lie between 1000 and the microphone's
    default sample rate. A warning is logged when it is lower than that default.

    Raises:
        RuntimeError: If the sample rate cannot be converted to an integer or is out of range.
    """
    microphone_info = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
    actual_sample_rate = microphone_info["sample_rate"]

    if self.sample_rate is None:
        self.sample_rate = actual_sample_rate
        return

    try:
        self.sample_rate = int(self.sample_rate)
    except ValueError as e:
        raise RuntimeError(
            f"Cannot convert the provided sample rate ({self.sample_rate} Hz) to an integer."
        ) from e

    if self.sample_rate > actual_sample_rate or self.sample_rate < 1000:
        raise RuntimeError(
            f"Provided sample rate {self.sample_rate} is either too low or too high compared to the sample rate of the microphone {actual_sample_rate}."
        )

    if self.sample_rate < actual_sample_rate:
        logger.warning(
            "Provided sample rate is lower than the sample rate of the microphone. Performance may be impacted."
        )
|
||||
|
||||
def _validate_channels(self) -> None:
    """Checks the configured channels against the microphone's available channels, or adopts all of them.

    Also stores ``channels_index``, the zero-based version of the (1-based) channel numbers.

    Raises:
        RuntimeError: If one of the configured channels is not available on the microphone.
    """
    microphone_info = PortAudioMicrophone.find_microphones(self.microphone_index, self.sounddevice_sdk)
    actual_channels = microphone_info["channels"]

    if self.channels is None or len(self.channels) == 0:
        self.channels = actual_channels
    elif any(channel not in actual_channels for channel in self.channels):
        raise RuntimeError(
            f"Some of the provided channels {self.channels} are outside the possible channel range of the microphone {actual_channels}."
        )

    # Get channels index instead of number for slicing
    self.channels_index = np.array(self.channels) - 1
|
||||
|
||||
def connect(self) -> None:
    """
    Connects the microphone and checks if the requested acquisition parameters are compatible with the microphone.

    Validates the settings, allocates the shared array and write queue, clears all recording events,
    then spawns the daemon recording process and waits up to 5 seconds for it to signal that it is initialized.

    Raises:
        DeviceAlreadyConnectedError: If the microphone is already connected.
        RuntimeError: If the recording process is not alive or did not signal initialization in time.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(f"Microphone {self.microphone_index} is already connected.")

    self._configure_capture_settings()

    # Create or reset queue and shared array
    # NOTE(review): the first dimension is presumably 10 seconds worth of samples - confirm.
    self.read_shared_array = SharedArray(
        shape=(self.sample_rate * 10, len(self.channels)),
        dtype=np.dtype("float32"),
    )
    self.local_read_shared_array = self.read_shared_array.get_local_array()
    self.write_queue = process_Queue()

    # Reset events
    self.record_start_event.clear()
    self.record_stop_event.clear()
    self.record_close_event.clear()
    self.record_is_started_event.clear()
    self.audio_callback_start_event.clear()

    # Create and start an audio input stream with a recording callback
    # Remark: this is done in a separate process so that audio recording is not impacted by the main thread CPU usage, especially the precise_sleep function.
    process_init_event = process_Event()
    self.record_process = Process(
        target=self._record_process,
        args=(
            self.microphone_index,
            self.sample_rate,
            self.channels,
            process_init_event,
            self.record_start_event,
            self.record_stop_event,
            self.record_close_event,
            self.record_is_started_event,
            self.audio_callback_start_event,
            self.write_queue,
            self.read_shared_array,
            self.sounddevice_sdk,
        ),
    )
    self.record_process.daemon = True
    self.record_process.start()

    # Wait for the recording process to be started, and to potentially raise an error on failure.
    is_init = process_init_event.wait(timeout=5.0)
    if not self.is_connected or not is_init:
        raise RuntimeError(f"Error connecting microphone {self.microphone_index}.")

    logger.info(f"{self} connected.")
|
||||
|
||||
def disconnect(self) -> None:
    """
    Disconnects the microphone and stops the recording.

    Stops any ongoing recording, signals the recording process to close, releases the shared array
    and the write queue, then waits for the recording process to exit.

    Raises:
        DeviceNotConnectedError: If the microphone is not connected.
        RuntimeError: If the recording process is still alive after being joined.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")

    if self.is_recording:
        self.stop_recording()

    # The recording process polls this event and leaves its loop when it is set.
    self.record_close_event.set()
    self.read_shared_array.delete()
    self.write_queue.close()
    self.record_process.join()

    if self.is_connected:
        raise RuntimeError(f"Error disconnecting microphone {self.microphone_index}.")

    logger.info(f"{self} disconnected.")
|
||||
|
||||
def _read(self) -> np.ndarray:
    """
    Reads the available audio data from the shared array through its local view, flushing it (``flush=True``).
    """
    shared_array = self.read_shared_array
    local_view = self.local_read_shared_array
    return shared_array.read(local_view, flush=True)
|
||||
|
||||
def read(self) -> np.ndarray:
    """
    Returns the audio recorded since the last read, or since the beginning of the recording.

    Also stores the read duration and a timestamp in ``self.logs``.

    Raises:
        DeviceNotConnectedError: If the microphone is not connected.
        RuntimeError: If the microphone is not recording.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
    if not self.is_recording:
        raise RuntimeError(f"Microphone {self.microphone_index} is not recording.")

    read_started_at = time.perf_counter()
    audio_readings = self._read()
    read_finished_at = time.perf_counter()

    # Number of seconds it took to read the audio chunk
    self.logs["delta_timestamp_s"] = read_finished_at - read_started_at

    # Time at which the audio chunk was received. Despite the key name, this is a
    # time.perf_counter() value, not a UTC wall-clock time.
    self.logs["timestamp_utc"] = time.perf_counter()

    return audio_readings
|
||||
|
||||
@staticmethod
def _record_process(
    microphone_index,
    sample_rate,
    channels,
    process_init_event,
    record_start_event,
    record_stop_event,
    record_close_event,
    record_is_started_event,
    audio_callback_start_event,
    write_queue,
    read_shared_array,
    sounddevice_sdk,
) -> None:
    """
    Process target that creates an unpicklable sounddevice audio input stream with a recording callback, then starts, stops and closes it based on multiprocessing events.

    Args:
        microphone_index: Device passed to the input stream.
        sample_rate: Sample rate passed to the input stream.
        channels: 1-based channel numbers to keep from each incoming buffer.
        process_init_event: Set once the stream has been created.
        record_start_event: Waited on (with a 0.1 s timeout) to start the stream.
        record_stop_event: Waited on to stop the started stream.
        record_close_event: When set, ends the loop and closes the stream.
        record_is_started_event: Set while the stream is started.
        audio_callback_start_event: Incoming audio is only forwarded while this is set.
        write_queue: Queue receiving each selected audio buffer.
        read_shared_array: Shared array receiving each selected audio buffer.
        sounddevice_sdk: Object providing the ``InputStream`` class.
    """

    # Channel numbers are 1-based, column indices are 0-based.
    channels_index = np.array(channels) - 1
    local_read_shared_array = read_shared_array.get_local_array()

    def audio_callback(indata, frames, timestamp, status) -> None:
        """
        Low-level sounddevice callback.

        Logs any reported status and, while the callback start event is set,
        forwards the selected channels to the write queue and the shared array.
        """
        if status:
            logger.warning(status)
        if audio_callback_start_event.is_set():
            write_queue.put_nowait(indata[:, channels_index])
            read_shared_array.write(local_read_shared_array, indata[:, channels_index])

    # Create the audio stream
    # InputStream must be instantiated in the process as it is not picklable.
    stream = sounddevice_sdk.InputStream(
        device=microphone_index,
        samplerate=sample_rate,
        # Open every channel up to the highest requested one; the callback selects the requested columns.
        channels=max(channels),
        dtype="float32",
        blocksize=0,  # Varying input buffer length, but no additional latency
        latency="low",  # Low latency mode (not enabled by default !)
        # never_drop_input=True, # Disabled as it generates an error for some devices
        callback=audio_callback,
    )
    process_init_event.set()

    while True:
        # Poll with a timeout so that a close request is noticed even when no start is requested.
        start_flag = record_start_event.wait(timeout=0.1)
        if record_close_event.is_set():
            break
        elif not start_flag:
            continue
        stream.start()
        record_is_started_event.set()
        record_stop_event.wait()
        stream.stop()  # stream.stop() waits for all buffers to be processed, stream.abort() flushes the buffers !
        record_is_started_event.clear()
    stream.close()
|
||||
|
||||
def start_recording(
    self,
    output_file: str | None = None,
    multiprocessing: bool | None = False,
    overwrite: bool | None = True,
    barrier: Barrier | None = None,
) -> None:
    """
    Starts the recording of the microphone. If output_file is provided, the audio will be written to this file.

    Args:
        output_file: Path of the file to write the audio to. No file is written when None.
        multiprocessing: If truthy, the file is written from a separate process instead of a thread.
        overwrite: If truthy, an existing output file is deleted first; otherwise its presence is an error.
        barrier: Optional barrier waited on after the input stream has started and before audio
            forwarding is enabled.

    Raises:
        DeviceNotConnectedError: If the microphone is not connected.
        DeviceAlreadyRecordingError: If the microphone is already recording.
        FileExistsError: If the output file exists and ``overwrite`` is falsy.
        RuntimeError: If recording or writing is not running at the end of this method.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
    if self.is_recording:
        raise DeviceAlreadyRecordingError(f"Microphone {self.microphone_index} is already recording.")

    # Reset queue and shared memory
    self.read_shared_array.reset()
    self._clear_queue(self.write_queue)

    # Reset stop event
    self.record_stop_event.clear()

    # Write recordings into a file if output_file is provided
    if output_file is not None:
        output_file = Path(output_file)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        if output_file.exists():
            if overwrite:
                output_file.unlink()
            else:
                raise FileExistsError(
                    f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
                )

        # The events must be of the same flavour (process or thread) as the writer using them.
        if multiprocessing:
            self.write_stop_event = process_Event()
            self.write_is_started_event = process_Event()
            self.write_thread = Process(
                target=PortAudioMicrophone._write_loop,
                args=(
                    self.write_queue,
                    self.write_stop_event,
                    self.write_is_started_event,
                    self.sample_rate,
                    self.channels,
                    output_file,
                ),
            )
        else:
            self.write_stop_event = thread_Event()
            self.write_is_started_event = thread_Event()
            self.write_thread = Thread(
                target=PortAudioMicrophone._write_loop,
                args=(
                    self.write_queue,
                    self.write_stop_event,
                    self.write_is_started_event,
                    self.sample_rate,
                    self.channels,
                    output_file,
                ),
            )
        self.write_thread.daemon = True
        self.write_thread.start()
        self.write_is_started_event.wait()  # Wait for the writing thread/process to be started.

    self.record_start_event.set()  # Start the input audio stream process
    self.record_is_started_event.wait()  # Wait for the input audio stream process to be actually started

    if barrier is not None:
        barrier.wait()  # Wait for multiple input audio streams to be started at the same time

    # From here on, the audio callback forwards incoming buffers.
    self.audio_callback_start_event.set()

    if not self.is_recording:
        raise RuntimeError(f"Error starting recording for microphone {self.microphone_index}.")
    if output_file is not None and not self.is_writing:
        raise RuntimeError(f"Error starting writing for microphone {self.microphone_index}.")
|
||||
|
||||
def stop_recording(self) -> None:
    """
    Stops the recording of the microphone.

    Disables audio forwarding, asks the recording process to stop the stream, waits up to about one
    second for it to do so, then resets the shared array, empties and joins the write queue and
    stops the writing thread/process if there is one.

    Raises:
        DeviceNotConnectedError: If the microphone is not connected.
        DeviceNotRecordingError: If the microphone is not recording.
        RuntimeError: If recording or writing is still running at the end of this method.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"Microphone {self.microphone_index} is not connected.")
    if not self.is_recording:
        raise DeviceNotRecordingError(f"Microphone {self.microphone_index} is not recording.")

    self.audio_callback_start_event.clear()
    self.record_start_event.clear()  # Ensures the audio stream is not started again !
    self.record_stop_event.set()

    # Wait for the stream to be stopped (might lead to race condition if the stream is not properly stopped on array reset and queue clearing)
    # Polls every 10 ms; the loop counts sleep time only, so the real wait may slightly exceed 1 s.
    timeout = 1.0
    while self.is_recording and timeout > 0:
        time.sleep(0.01)
        timeout -= 0.01

    self.read_shared_array.reset()
    self._clear_queue(self.write_queue, join_queue=True)

    if self.is_writing:
        self.write_stop_event.set()
        self.write_thread.join()

    if self.is_recording:
        raise RuntimeError(f"Error stopping recording for microphone {self.microphone_index}.")
    if self.is_writing:
        raise RuntimeError(f"Error stopping writing for microphone {self.microphone_index}.")
|
||||
|
||||
@staticmethod
def _write_loop(
    queue,
    write_stop_event: Event,
    write_is_started_event: Event,
    sample_rate: int,
    channels: list[int],
    output_file: Path,
) -> None:
    """
    Thread/Process-safe loop to write audio data into a file.

    Args:
        queue: Queue the audio buffers are taken from; ``task_done()`` is called for each written buffer.
        write_stop_event: The loop ends once this event is set.
        write_is_started_event: Set once the file is open and cleared when the loop has ended.
        sample_rate: Sample rate of the written file.
        channels: Recorded channels; only their count is used, as the file's channel count.
        output_file: Path of the WAV file to write.
    """
    # Can only be run on a single process/thread for file writing safety
    with SoundFile(
        output_file,
        mode="w",
        samplerate=sample_rate,
        channels=len(channels),
        format="WAV",
        subtype="FLOAT",  # By default, a much lower quality WAV file is created !
    ) as file:
        write_is_started_event.set()
        while not write_stop_event.is_set():
            try:
                # Timeout set as the usual sounddevice buffer size. get_nowait is not possible here as it saturates the thread.
                file.write(queue.get(timeout=0.005))
                queue.task_done()
            except Empty:
                continue
        write_is_started_event.clear()
|
||||
|
||||
def __del__(self) -> None:
    """Disconnect the microphone on garbage collection if it is still connected.

    Safe to run on a partially initialised instance, i.e. one whose ``__init__``
    raised before ``record_process`` was assigned.
    """
    # is_connected reads self.record_process, which does not exist yet if __init__ failed early.
    if getattr(self, "record_process", None) is None:
        return
    if self.is_connected:
        self.disconnect()
|
||||
|
||||
@staticmethod
def _clear_queue(queue, join_queue: bool = False):
    """
    Empties the queue by taking items until it reports empty, marking each one as done. The longer the queue, the longer this takes.

    Args:
        queue: Queue to empty; must provide ``get_nowait()``, ``task_done()`` and ``join()``.
        join_queue: If True, ``queue.join()`` is called once the queue is empty.
    """
    while True:
        try:
            queue.get_nowait()
        except Empty:
            break
        queue.task_done()

    if join_queue:
        queue.join()
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_touchlab import TouchLabSensorConfig
|
||||
from .sensor_touchlab import TouchLabSensor
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import MicrophoneConfig
|
||||
|
||||
|
||||
@MicrophoneConfig.register_subclass("touchlab")
@dataclass
class TouchLabSensorConfig(MicrophoneConfig):
    """Configuration class for TouchLab tactile sensors (technically not a microphone, but behaves like one acquisition-wise).

    This class provides configuration options for TouchLab tactile sensors, including serial port, baud rate, sample rate and channels.

    Example configurations:
    ```python
    # Basic configurations
    # NOTE(review): keyword arguments are used because the positional order depends on the
    # fields inherited from MicrophoneConfig, which is not visible here - confirm.
    TouchLabSensorConfig(sensor_port="/dev/ttyACM0", sample_rate=16000)  # Serial port /dev/ttyACM0, 16000Hz
    TouchLabSensorConfig(sensor_port="/dev/ttyACM1", sample_rate=44100)  # Serial port /dev/ttyACM1, 44100Hz
    ```

    Attributes:
        sensor_port: Serial port of the tactile sensor.
        baud_rate: Baud rate of the tactile sensor (defaults to 115200).
        sample_rate: Sample rate in Hz for the tactile sensor (presumably inherited from MicrophoneConfig).
        channels: List of channel numbers to use for the tactile sensor (presumably inherited from MicrophoneConfig).
    """

    # Serial port of the tactile sensor, e.g. "/dev/ttyACM0"
    sensor_port: str
    # Baud rate of the serial connection
    baud_rate: int = 115_200
|
||||
@@ -0,0 +1,469 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the TouchLabSensor class for capturing tactile data from TouchLab tactile sensors.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import (
|
||||
Event as process_Event,
|
||||
JoinableQueue as process_Queue,
|
||||
Process,
|
||||
)
|
||||
from pathlib import Path
|
||||
from queue import Empty
|
||||
from threading import Barrier, Event, Event as thread_Event, Thread
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from serial import Serial
|
||||
from soundfile import SoundFile
|
||||
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
from ..microphone import Microphone
|
||||
from .configuration_touchlab import TouchLabSensorConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_SERIAL_READ_SIZE = 512
|
||||
|
||||
|
||||
class TouchLabSensor(Microphone):
    """
    The TouchLabSensor class handles all TouchLab tactile sensors.

    The sensor is exposed through the `Microphone` interface: comma-separated integer frames are read
    from a serial port by a dedicated process, and each frame is published both to a shared array
    (consumed by `read`) and to a queue (consumed by the optional file writer).

    A TouchLabSensor instance requires the serial port of the tactile sensor, which may be obtained using
    `python -m lerobot.find_port`. It also requires the recording sample rate as well as the list of
    recorded channels.

    Example of usage:
    ```python
    from lerobot.common.robot_devices.microphones.configs import TouchLabSensorConfig

    config = TouchLabSensorConfig(sensor_port="/dev/ttyACM0", baud_rate=115200, sample_rate=115, channels=[1])
    microphone = TouchLabSensor(config)

    microphone.connect()
    microphone.start_recording("some/output/file.wav")
    ...
    # Gets all data recorded since the last read or since the beginning of the recording.
    # The longer the period, the longer the reading time!
    audio_readings = microphone.read()
    ...
    microphone.stop_recording()
    microphone.disconnect()
    ```
    """

    def __init__(self, config: TouchLabSensorConfig):
        """
        Initializes the TouchLabSensor instance.

        Args:
            config: The configuration settings for the sensor.
        """
        # NOTE(review): `self.sample_rate` and `self.channels` are used below but never assigned in this
        # class; presumably the `Microphone` base initializer sets them from `config` - confirm.
        super().__init__(config)

        # Sensor port
        self.sensor_port = config.sensor_port

        # Baud rate
        self.baud_rate = config.baud_rate

        # Recording process and the events used to drive it from the main process
        self.record_process = None
        self.record_stop_event = process_Event()
        self.record_start_event = process_Event()
        self.record_close_event = process_Event()
        self.record_is_started_event = process_Event()
        self.audio_callback_start_event = process_Event()

        # Process-safe concurrent queue to send data from the recording process to the writing process/thread
        self.write_queue = process_Queue()

        # SharedArray to store data from the recording process.
        self.read_shared_array = None
        self.local_read_shared_array = None

        # Thread/Process to handle data writing in a separate thread/process (safely)
        self.write_thread = None
        self.write_stop_event = None
        self.write_is_started_event = None

        self.logs = {}

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self.sensor_port})"

    @property
    def is_connected(self) -> bool:
        """Check if the sensor is currently connected.

        Returns:
            bool: True if the recording process exists and is alive, False otherwise.
        """
        return self.record_process is not None and self.record_process.is_alive()

    @property
    def is_recording(self) -> bool:
        """Check if the sensor is currently recording.

        Returns:
            bool: True if the recording process reports that it is reading the serial port, False otherwise.
        """
        return self.record_is_started_event.is_set()

    @property
    def is_writing(self) -> bool:
        """Check if the sensor is currently writing to a file.

        Returns:
            bool: True if a writer thread/process was created and has signalled that it started,
                False otherwise.
        """
        return self.write_thread is not None and self.write_is_started_event.is_set()

    @staticmethod
    def find_microphones() -> list[dict[str, Any]]:
        """Detects available sensors connected to the system.

        Returns:
            List[Dict[str, Any]]: A list of dictionaries,
                where each dictionary contains information about a detected sensor.
        """
        # NOTE(review): detection is not implemented; this currently returns None despite the annotation.
        return None

    def connect(self) -> None:
        """
        Establish connection to the sensor.

        Allocates the shared array and queue, resets all control events and spawns the recording process.
        The serial port itself is only opened by that process once recording is started.

        Raises:
            DeviceAlreadyConnectedError: If the recording process is already alive.
            RuntimeError: If the recording process did not signal readiness within 5 seconds.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"Sensor connected to {self.sensor_port} is already connected.")

        # Create or reset queue and shared array (`sample_rate * 10` rows, one column per channel)
        self.read_shared_array = SharedArray(
            shape=(self.sample_rate * 10, len(self.channels)),
            dtype=np.dtype("int16"),
        )
        self.local_read_shared_array = self.read_shared_array.get_local_array()
        self.write_queue = process_Queue()

        # Reset events
        self.record_start_event.clear()
        self.record_stop_event.clear()
        self.record_close_event.clear()
        self.record_is_started_event.clear()
        self.audio_callback_start_event.clear()

        # Run the acquisition loop in a separate process so that recording is not impacted by the main
        # thread CPU usage, especially the precise_sleep function.
        process_init_event = process_Event()
        self.record_process = Process(
            target=self._record_process,
            args=(
                self.sensor_port,
                self.baud_rate,
                self.channels,
                process_init_event,
                self.record_start_event,
                self.record_stop_event,
                self.record_close_event,
                self.record_is_started_event,
                self.audio_callback_start_event,
                self.write_queue,
                self.read_shared_array,
            ),
        )
        self.record_process.daemon = True
        self.record_process.start()

        # Wait for the recording process to be started, and to potentially raise an error on failure.
        is_init = process_init_event.wait(timeout=5.0)
        if not self.is_connected or not is_init:
            raise RuntimeError(f"Error connecting sensor connected to {self.sensor_port}.")

        logger.info(f"{self} connected.")

    @staticmethod
    def _record_process(
        sensor_port,
        baud_rate,
        channels,
        process_init_event,
        record_start_event,
        record_stop_event,
        record_close_event,
        record_is_started_event,
        audio_callback_start_event,
        write_queue,
        read_shared_array,
    ) -> None:
        """
        Entry point of the recording process.

        Idles until `record_start_event` is set, then opens the serial port and forwards every valid frame
        to `write_queue` and `read_shared_array` until `record_stop_event` is set. The process exits when
        `record_close_event` is set while idle.
        """
        # Channels are given 1-based, frame columns are 0-based
        channels_index = np.array(channels) - 1
        max_channel_index = int(channels_index.max()) if channels_index.size else -1
        local_read_shared_array = read_shared_array.get_local_array()

        def tactile_callback(serial_connection):
            """
            Read one line from the serial port and publish the selected channels.

            Malformed frames are dropped instead of crashing the process.
            """
            buffer = serial_connection.readline()

            if not audio_callback_start_event.is_set():
                return

            # A read timeout yields an empty line and a partial read yields garbage fields: both make
            # `int()` fail, so such frames are skipped.
            try:
                values = [int(field) for field in buffer.decode("utf8").split(",")]
            except (UnicodeDecodeError, ValueError):
                return

            num_taxels = len(values)
            # Make sure we didn't read rubbish, and that every requested channel is present in the frame
            if num_taxels >= MAX_SERIAL_READ_SIZE or num_taxels <= max_channel_index:
                return

            # NOTE(review): frames are float64 while the shared array is int16 and the file subtype is
            # PCM_16 - confirm the conversion is intended.
            indata = np.array([values], dtype=np.float64)
            selected = indata[:, channels_index]

            write_queue.put_nowait(selected)
            read_shared_array.write(local_read_shared_array, selected)

        process_init_event.set()

        while True:
            start_flag = record_start_event.wait(timeout=0.1)
            if record_close_event.is_set():
                break
            elif not start_flag:
                continue

            # The context manager closes the port on exit, including on error
            with Serial(sensor_port, baud_rate, timeout=0.5) as serial_connection:
                serial_connection.flush()
                record_is_started_event.set()
                try:
                    while not record_stop_event.is_set():
                        tactile_callback(serial_connection)
                finally:
                    # Never leave `is_recording` stuck at True if the read loop fails
                    record_is_started_event.clear()

    def disconnect(self) -> None:
        """
        Disconnect the sensor and release any resources.

        Raises:
            DeviceNotConnectedError: If the sensor is not connected.
            RuntimeError: If the recording process is still alive after being joined.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")

        if self.is_recording:
            self.stop_recording()

        self.record_close_event.set()
        self.read_shared_array.delete()
        self.write_queue.close()
        self.record_process.join()

        if self.is_connected:
            raise RuntimeError(f"Error disconnecting sensor connected to {self.sensor_port}.")

        logger.info(f"{self} disconnected.")

    @staticmethod
    def _wait_until_started(started_event, worker, error_message: str) -> None:
        """Block until `started_event` is set, raising RuntimeError if `worker` dies before setting it."""
        while not started_event.wait(timeout=0.1):
            if not worker.is_alive() and not started_event.is_set():
                raise RuntimeError(error_message)

    def start_recording(
        self,
        output_file: str | Path | None = None,
        multiprocessing: bool | None = False,
        overwrite: bool | None = True,
        barrier: Barrier | None = None,
    ) -> None:
        """
        Start recording tactile data from the sensor.

        Args:
            output_file: Optional path to save the recorded tactile data.
            multiprocessing: If True, the file writer runs in a process. Defaults to a thread otherwise.
            overwrite: If True, overwrites existing files at output_file path.
            barrier: If not None, ensures that multiple sensors start recording at the same time.

        Raises:
            DeviceNotConnectedError: If the sensor is not connected.
            DeviceAlreadyRecordingError: If the sensor is already recording.
            FileExistsError: If `output_file` exists and `overwrite` is False.
            RuntimeError: If the writer or the recording process fails to start.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if self.is_recording:
            raise DeviceAlreadyRecordingError(f"Sensor connected to {self.sensor_port} is already recording.")

        # Reset queue and shared memory
        self.read_shared_array.reset()
        self._clear_queue(self.write_queue)

        # Reset stop event
        self.record_stop_event.clear()

        # Write recordings into a file if output_file is provided
        if output_file is not None:
            output_file = Path(output_file)
            output_file.parent.mkdir(parents=True, exist_ok=True)

            if output_file.exists():
                if overwrite:
                    output_file.unlink()
                else:
                    raise FileExistsError(
                        f"Output file {output_file} already exists. Set overwrite to True to overwrite it."
                    )

            # Same writer loop either way; only the worker and event flavours differ
            if multiprocessing:
                event_factory, worker_factory = process_Event, Process
            else:
                event_factory, worker_factory = thread_Event, Thread

            self.write_stop_event = event_factory()
            self.write_is_started_event = event_factory()
            self.write_thread = worker_factory(
                target=TouchLabSensor._write_loop,
                args=(
                    self.write_queue,
                    self.write_stop_event,
                    self.write_is_started_event,
                    self.sample_rate,
                    self.channels,
                    output_file,
                ),
            )
            self.write_thread.daemon = True
            self.write_thread.start()
            # Wait for the writing thread/process to be started, without hanging if it died (e.g. the
            # output file could not be opened).
            self._wait_until_started(
                self.write_is_started_event,
                self.write_thread,
                f"Error starting writing for sensor connected to {self.sensor_port}.",
            )

        self.record_start_event.set()  # Ask the recording process to open the serial port
        # Wait for the serial stream to be actually started, without hanging if the recording process
        # died (e.g. the serial port could not be opened).
        self._wait_until_started(
            self.record_is_started_event,
            self.record_process,
            f"Error starting recording for sensor connected to {self.sensor_port}.",
        )

        if barrier is not None:
            barrier.wait()  # Wait for multiple input streams to be started at the same time

        self.audio_callback_start_event.set()

        if not self.is_recording:
            raise RuntimeError(f"Error starting recording for sensor connected to {self.sensor_port}.")
        if output_file is not None and not self.is_writing:
            raise RuntimeError(f"Error starting writing for sensor connected to {self.sensor_port}.")

    def _read(self) -> np.ndarray:
        """
        Read the data currently available in the shared array, flushing it in the process.
        """
        return self.read_shared_array.read(self.local_read_shared_array, flush=True)

    def read(self) -> np.ndarray:
        """Capture and return the data recorded since the previous read.

        Returns:
            np.ndarray: Captured data chunk as a numpy array.

        Raises:
            DeviceNotConnectedError: If the sensor is not connected.
            RuntimeError: If the sensor is not recording.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if not self.is_recording:
            raise RuntimeError(f"Sensor connected to {self.sensor_port} is not recording.")

        start_time = time.perf_counter()

        tactile_readings = self._read()

        # log the number of seconds it took to read the chunk
        self.logs["delta_timestamp_s"] = time.perf_counter() - start_time

        # log the time at which the chunk was received
        # NOTE(review): despite the key name this is a `perf_counter` value, not a UTC timestamp.
        self.logs["timestamp_utc"] = time.perf_counter()

        return tactile_readings

    def _read_loop(self) -> None:
        """Internal loop run by the background thread for asynchronous reading (currently a no-op)."""

    def stop_recording(self) -> None:
        """Stop recording data from the sensor.

        Raises:
            DeviceNotConnectedError: If the sensor is not connected.
            DeviceNotRecordingError: If the sensor is not recording.
            RuntimeError: If recording or writing is still active after the stop sequence.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"Sensor connected to {self.sensor_port} is not connected.")
        if not self.is_recording:
            raise DeviceNotRecordingError(f"Sensor connected to {self.sensor_port} is not recording.")

        self.audio_callback_start_event.clear()
        self.record_start_event.clear()  # Ensures the stream is not started again!
        self.record_stop_event.set()

        self.read_shared_array.reset()
        self._clear_queue(self.write_queue, join_queue=True)

        if self.is_writing:
            self.write_stop_event.set()
            self.write_thread.join()

        # Give the recording process up to about one second to leave its read loop
        timeout = 1.0
        while self.is_recording and timeout > 0:
            time.sleep(0.01)
            timeout -= 0.01

        if self.is_recording:
            raise RuntimeError(f"Error stopping recording for sensor connected to {self.sensor_port}.")
        if self.is_writing:
            raise RuntimeError(f"Error stopping writing for sensor connected to {self.sensor_port}.")

    def __del__(self) -> None:
        # `__init__` may have failed before `record_process` was assigned; nothing to release then
        if getattr(self, "record_process", None) is None:
            return
        if self.is_connected:
            self.disconnect()

    @staticmethod
    def _clear_queue(queue, join_queue: bool = False):
        """
        Clears the queue by getting all items until it is empty. The longer the queue, the longer it takes to clear it.

        Args:
            queue: The joinable queue to drain.
            join_queue: If True, also waits for every item to be marked as done once the queue is empty.
        """
        try:
            while True:
                queue.get_nowait()
                queue.task_done()
        except Empty:
            if join_queue:
                queue.join()
            return

    @staticmethod
    def _write_loop(
        queue,
        write_stop_event: Event,
        write_is_started_event: Event,
        sample_rate: int,
        channels: list[int],
        output_file: Path,
    ) -> None:
        """
        Thread/Process-safe loop to write queued data into a WAV file until `write_stop_event` is set.
        """
        # Can only be run on a single process/thread for file writing safety
        with SoundFile(
            output_file,
            mode="w",
            samplerate=sample_rate,
            channels=len(channels),
            format="WAV",
            subtype="PCM_16",  # Subtype for int16 values
        ) as file:
            write_is_started_event.set()
            while not write_stop_event.is_set():
                try:
                    # A short blocking get is used because get_nowait would saturate the thread.
                    file.write(queue.get(timeout=0.005))
                    queue.task_done()
                except Empty:
                    continue
            write_is_started_event.clear()
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from multiprocessing import Barrier
|
||||
from threading import Thread
|
||||
|
||||
from .configs import MicrophoneConfig
|
||||
from .microphone import Microphone
|
||||
|
||||
|
||||
def make_microphones_from_configs(microphone_configs: dict[str, MicrophoneConfig]) -> dict[str, Microphone]:
    """
    Instantiate one device per configuration entry, keyed like the input mapping.

    The concrete class is chosen from `cfg.type` ("portaudio", "touchlab" or "anyskin") and imported
    lazily, so only the backends that are actually requested get imported.

    Raises:
        ValueError: If a configuration has a type that is not one of the supported values.
    """
    built = {}

    for name, config in microphone_configs.items():
        kind = config.type

        if kind == "portaudio":
            from .portaudio import PortAudioMicrophone as device_class
        elif kind == "touchlab":
            from .touchlab import TouchLabSensor as device_class
        elif kind == "anyskin":
            from .anyskin import AnyskinSensor as device_class
        else:
            raise ValueError(f"The microphone type '{kind}' is not valid.")

        built[name] = device_class(config)

    return built
|
||||
|
||||
|
||||
def async_microphones_start_recording(
|
||||
microphones: dict[str, Microphone],
|
||||
output_files: list[str | None] | None = None,
|
||||
multiprocessing: bool = False,
|
||||
overwrite: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Starts recording on multiple microphones asynchronously to avoid delays.
|
||||
|
||||
Args:
|
||||
microphones: A dictionary of microphones.
|
||||
output_files: A list of output files.
|
||||
multiprocessing: If True, enables multiprocessing for recording.
|
||||
overwrite: If True, overwrites existing files at output_file path.
|
||||
"""
|
||||
|
||||
start_recording_threads = []
|
||||
if output_files is None:
|
||||
output_files = [None] * len(microphones)
|
||||
|
||||
barrier = Barrier(len(microphones))
|
||||
|
||||
for microphone, output_file in zip(microphones.values(), output_files, strict=False):
|
||||
start_recording_threads.append(
|
||||
Thread(target=microphone.start_recording, args=(output_file, multiprocessing, overwrite, barrier))
|
||||
)
|
||||
|
||||
for thread in start_recording_threads:
|
||||
thread.start()
|
||||
for thread in start_recording_threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
def async_microphones_stop_recording(microphones: dict[str, Microphone]) -> None:
    """
    Stops recording on multiple microphones asynchronously to avoid delays.

    Every microphone gets its own thread running `stop_recording`; the call returns once all of
    those threads have finished.

    Args:
        microphones: A dictionary of microphones.
    """
    workers = [Thread(target=device.stop_recording) for device in microphones.values()]

    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
|
||||
@@ -98,6 +98,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"AUDIO": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MEAN_STD,
|
||||
}
|
||||
@@ -108,6 +109,10 @@ class ACTConfig(PreTrainedConfig):
|
||||
vision_backbone: str = "resnet18"
|
||||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||||
replace_final_stride_with_dilation: int = False
|
||||
# Audio backbone.
|
||||
audio_backbone: str = vision_backbone
|
||||
pretrained_backbone_weights_audio: str | None = None
|
||||
replace_final_stride_with_dilation_audio: int = False
|
||||
# Transformer layers.
|
||||
pre_norm: bool = False
|
||||
dim_model: int = 512
|
||||
@@ -170,8 +175,10 @@ class ACTConfig(PreTrainedConfig):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features and not self.env_state_feature:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
if not (self.image_features or self.audio_features) and not self.env_state_feature:
|
||||
raise ValueError(
|
||||
"You must provide at least one image/audio or the environment state among the inputs."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> None:
|
||||
|
||||
@@ -35,7 +35,7 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import ACTION, OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTPolicy(PreTrainedPolicy):
|
||||
@@ -106,6 +106,8 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
actions = self.predict_action_chunk(batch)
|
||||
action = self.temporal_ensembler.update(actions)
|
||||
@@ -331,12 +333,26 @@ class ACT(nn.Module):
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||
|
||||
# Backbone for audio feature extraction.
|
||||
if self.config.audio_features:
|
||||
audio_backbone_model = getattr(torchvision.models, config.audio_backbone)(
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation_audio],
|
||||
weights=config.pretrained_backbone_weights_audio,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final
|
||||
# feature map).
|
||||
# Note: The forward method of this returns a dict: {"feature_map": output}.
|
||||
self.audio_backbone = IntermediateLayerGetter(
|
||||
audio_backbone_model, return_layers={"layer4": "feature_map"}
|
||||
)
|
||||
|
||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||
self.encoder = ACTEncoder(config)
|
||||
self.decoder = ACTDecoder(config)
|
||||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
|
||||
# [latent, (robot_state), (env_state), (image_feature_map_pixels), (audio_feature)].
|
||||
if self.config.robot_state_feature:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
self.config.robot_state_feature.shape[0], config.dim_model
|
||||
@@ -350,6 +366,10 @@ class ACT(nn.Module):
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
if self.config.audio_features:
|
||||
self.encoder_audio_feat_input_proj = nn.Conv2d(
|
||||
audio_backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
n_1d_tokens = 1 # for the latent
|
||||
if self.config.robot_state_feature:
|
||||
@@ -359,6 +379,8 @@ class ACT(nn.Module):
|
||||
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
|
||||
if self.config.image_features:
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
if self.config.audio_features:
|
||||
self.encoder_audio_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||
@@ -483,6 +505,21 @@ class ACT(nn.Module):
|
||||
encoder_in_tokens.extend(list(cam_features))
|
||||
encoder_in_pos_embed.extend(list(cam_pos_embed))
|
||||
|
||||
if self.config.audio_features:
|
||||
for audio in batch[OBS_AUDIO]:
|
||||
audio_features = self.audio_backbone(audio)["feature_map"]
|
||||
audio_pos_embed = self.encoder_audio_feat_pos_embed(audio_features).to(
|
||||
dtype=audio_features.dtype
|
||||
)
|
||||
audio_features = self.encoder_audio_feat_input_proj(audio_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
audio_features = einops.rearrange(audio_features, "b c h w -> (h w) b c")
|
||||
audio_pos_embed = einops.rearrange(audio_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
encoder_in_tokens.extend(list(audio_features))
|
||||
encoder_in_pos_embed.extend(list(audio_pos_embed))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0)
|
||||
|
||||
@@ -17,9 +17,11 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
AudioProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -63,6 +65,15 @@ def make_act_pre_post_processors(
|
||||
stats=dataset_stats,
|
||||
device=config.device,
|
||||
),
|
||||
AudioProcessorStep(
|
||||
output_height=224,
|
||||
output_width=224,
|
||||
output_channels=3,
|
||||
input_audio_chunk_duration=DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
input_sample_rate=48000,
|
||||
intermediate_sample_rate=16000,
|
||||
n_fft=1024,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
|
||||
@@ -35,7 +35,6 @@ from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
@@ -68,7 +67,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "pi05_video", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
return PI05Policy
|
||||
elif name == "pi05_video":
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
|
||||
return PI05VideoPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
@@ -152,7 +147,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "pi05_video", "sac", "smolvla",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
@@ -174,8 +169,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi05":
|
||||
return PI05Config(**kwargs)
|
||||
elif policy_type == "pi05_video":
|
||||
return PI05VideoConfig(**kwargs)
|
||||
elif policy_type == "sac":
|
||||
return SACConfig(**kwargs)
|
||||
elif policy_type == "smolvla":
|
||||
@@ -340,14 +333,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, PI05VideoConfig):
|
||||
from lerobot.policies.videovla.processor_pi05 import make_pi05_video_pre_post_processors
|
||||
|
||||
processors = make_pi05_video_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SACConfig):
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
|
||||
|
||||
@@ -460,8 +460,8 @@ class PaliGemmaWithExpertModel(
|
||||
inputs_embeds=inputs_embeds[1],
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=False,
|
||||
past_key_values=None, #jadechoghari
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond[1] if adarms_cond is not None else None,
|
||||
)
|
||||
suffix_output = suffix_output.last_hidden_state
|
||||
@@ -575,13 +575,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
# try:
|
||||
# from transformers.models.siglip import check
|
||||
try:
|
||||
from transformers.models.siglip import check
|
||||
|
||||
# if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
# raise ValueError(msg)
|
||||
# except ImportError:
|
||||
# raise ValueError(msg) from None
|
||||
if not check.check_whether_transformers_replace_is_installed_correctly():
|
||||
raise ValueError(msg)
|
||||
except ImportError:
|
||||
raise ValueError(msg) from None
|
||||
|
||||
def gradient_checkpointing_enable(self):
|
||||
"""Enable gradient checkpointing for memory optimization."""
|
||||
|
||||
@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
|
||||
This function takes a dictionary of NumPy arrays, performs necessary
|
||||
preprocessing, and prepares it for model inference. The steps include:
|
||||
1. Converting NumPy arrays to PyTorch tensors.
|
||||
2. Normalizing and permuting image data (if any).
|
||||
2. Normalizing and permuting image data and audio data (if any).
|
||||
3. Adding a batch dimension to each tensor.
|
||||
4. Moving all tensors to the specified compute device.
|
||||
5. Adding task and robot type information to the dictionary.
|
||||
@@ -129,6 +129,9 @@ def prepare_observation_for_inference(
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
elif "audio" in name:
|
||||
observation[name] = observation[name].type(torch.float32)
|
||||
observation[name] = observation[name].permute(1, 0).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
# π₀.₅ (pi05)
|
||||
|
||||
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
|
||||
It is designed as a **Vision-Language-Action model with open-world generalization**.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | π₀ | π₀.₅ |
|
||||
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
|
||||
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
|
||||
| AdaRMS | Not used | Used in action expert |
|
||||
| Tokenizer Length | 48 tokens | 200 tokens |
|
||||
| Discrete State Input | False (Uses `state_proj` layer) | True |
|
||||
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
|
||||
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
|
||||
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
|
||||
year = {2025},
|
||||
eprint = {2504.16054},
|
||||
archivePrefix= {arXiv},
|
||||
primaryClass = {cs.LG},
|
||||
url = {https://arxiv.org/abs/2504.16054},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lazy imports to avoid conflicts with lerobot.policies.pi05.PI05Config
|
||||
# when only importing subpackages like videoprism
|
||||
def __getattr__(name):
|
||||
if name == "PI05VideoConfig":
|
||||
from .configuration_pi05 import PI05VideoConfig
|
||||
return PI05VideoConfig
|
||||
elif name == "PI05VideoPolicy":
|
||||
from .modeling_pi05 import PI05VideoPolicy
|
||||
return PI05VideoPolicy
|
||||
elif name == "make_pi05_video_pre_post_processors":
|
||||
from .processor_pi05 import make_pi05_video_pre_post_processors
|
||||
return make_pi05_video_pre_post_processors
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
__all__ = ["PI05VideoConfig", "PI05VideoPolicy", "make_pi05_video_pre_post_processors"]
|
||||
@@ -1,212 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
DEFAULT_IMAGE_SIZE = 224
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05_video")
@dataclass
class PI05VideoConfig(PreTrainedConfig):
    """Configuration for the `pi05_video` policy.

    Groups the model, flow-matching, optimizer and scheduler settings together
    with the optional VideoPrism video-encoder settings. The video-encoder
    settings are only validated when `use_video_encoder` is enabled.

    Raises:
        ValueError: From `__post_init__` when `n_action_steps` exceeds
            `chunk_size`, when a variant or dtype name is not recognised, or
            when an enabled video encoder has invalid settings.
    """

    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    n_obs_steps: int = 1
    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
    n_action_steps: int = 50  # Number of action steps to execute

    # Video encoder settings (VideoPrism)
    use_video_encoder: bool = False  # Enable video encoding with VideoPrism
    video_num_frames: int = 16  # Number of frames for video encoding (VideoPrism default is 16)
    videoprism_model_name: str = "MHRDYN7/videoprism-base-f16r288"  # VideoPrism model to use
    videoprism_image_size: int = 288  # VideoPrism expects 288x288 images
    freeze_video_encoder: bool = True  # Whether to freeze the video encoder weights
    video_padding_mode: str = "repeat"  # How to pad frames at episode start: "repeat" or "zero"
    # Which camera to use for video encoding (None = first camera, or specify key like "observation.images.top")
    video_encoder_camera_key: str | None = None
    # Perceiver Resampler settings to reduce video tokens (4096 -> video_num_latents)
    video_num_latents: int = 256  # Number of latent tokens for video resampler
    video_resampler_num_heads: int = 8  # Number of attention heads in resampler

    # Shorter state and action vectors will be padded to these dimensions
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Flow matching parameters: see openpi `PI0Pytorch`
    num_inference_steps: int = 10
    time_sampling_beta_alpha: float = 1.5
    time_sampling_beta_beta: float = 1.0
    time_sampling_scale: float = 0.999
    time_sampling_offset: float = 0.001
    min_period: float = 4e-3
    max_period: float = 4.0

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

    image_resolution: tuple[int, int] = (
        DEFAULT_IMAGE_SIZE,
        DEFAULT_IMAGE_SIZE,
    )  # see openpi `preprocessing_pytorch.py`

    # Add empty images. Used to add empty cameras when no image features are present.
    empty_cameras: int = 0

    # Declared once: the original file repeated this field after the scheduler
    # settings with the same default, which only invited the two copies to drift.
    tokenizer_max_length: int = 200  # see openpi `__post_init__`

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for state
            "ACTION": NormalizationMode.QUANTILES,  # Pi0.5 uses quantiles for action
        }
    )

    # Training settings
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
    compile_model: bool = False  # Whether to use torch.compile for model optimization
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Finetuning settings
    freeze_vision_encoder: bool = False  # Freeze only the vision encoder
    train_expert_only: bool = False  # Freeze entire VLM, train only action expert and projections

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    # Scheduler settings: see openpi `CosineDecaySchedule`
    # Note: These will auto-scale if --steps < scheduler_decay_steps
    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self):
        """Run the parent initialisation, then validate the settings defined here."""
        super().__post_init__()

        # Validate configuration
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
            )

        if self.paligemma_variant not in ("gemma_300m", "gemma_2b"):
            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")

        if self.action_expert_variant not in ("gemma_300m", "gemma_2b"):
            raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")

        if self.dtype not in ("bfloat16", "float32"):
            raise ValueError(f"Invalid dtype: {self.dtype}")

        # Validate video encoder settings (skipped entirely when the encoder is disabled)
        if self.use_video_encoder:
            if self.video_num_frames < 1:
                raise ValueError(f"video_num_frames must be >= 1, got {self.video_num_frames}")
            if self.videoprism_image_size < 1:
                raise ValueError(f"videoprism_image_size must be >= 1, got {self.videoprism_image_size}")
            if self.video_padding_mode not in ("repeat", "zero"):
                raise ValueError(
                    f"video_padding_mode must be 'repeat' or 'zero', got {self.video_padding_mode}"
                )

    def validate_features(self) -> None:
        """Validate and set up input/output features.

        Registers `empty_cameras` placeholder visual features, and adds default
        state/action features (sized to the padded dimensions) when they are
        not already declared.
        """
        for i in range(self.empty_cameras):
            key = OBS_IMAGES + f".empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, *self.image_resolution),  # Use configured image resolution
            )
            self.input_features[key] = empty_camera

        if OBS_STATE not in self.input_features:
            state_feature = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.max_state_dim,),  # Padded to max_state_dim
            )
            self.input_features[OBS_STATE] = state_feature

        if ACTION not in self.output_features:
            action_feature = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.max_action_dim,),  # Padded to max_action_dim
            )
            self.output_features[ACTION] = action_feature

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return an `AdamWConfig` built from the `optimizer_*` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Return a cosine-decay-with-warmup scheduler config peaking at `optimizer_lr`."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list[int] | None:
        """Return indices for delta observations.

        For PI05, we don't use generic observation_delta_indices because it would
        apply to both images AND state. Instead, we use image_observation_delta_indices
        which only applies to image observations.
        """
        return None

    @property
    def image_observation_delta_indices(self) -> list[int] | None:
        """Return indices for delta image observations only.

        When video encoding is enabled, returns indices for the past frames
        needed by VideoPrism (e.g., -15, -14, ..., -1, 0 for 16 frames).
        This only applies to image observations, not state.
        """
        if self.use_video_encoder:
            # Return indices for past frames: [-15, -14, ..., -1, 0] for 16 frames
            return list(range(-(self.video_num_frames - 1), 1))
        return None

    @property
    def action_delta_indices(self) -> list[int]:
        """Return `[0, 1, ..., chunk_size - 1]`, one index per predicted action step."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """Always `None`: this config declares no reward delta indices."""
        return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
    """
    Processor step that rewrites each task string into the full PI05 text prompt.

    The observation state is padded to `max_state_dim`, discretized into 256 bins
    and embedded in the prompt as
    `"Task: <task>, State: <bins>;\\nAction: "`. The prompts replace the entry
    stored under `task_key` in the complementary data.
    """

    max_state_dim: int = 32
    task_key: str = "task"

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a copy of `transition` whose task entry holds the full prompts.

        Raises:
            ValueError: If the observation has no state, or the complementary
                data has no entry under `task_key`.
        """
        transition = transition.copy()

        state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
        if state is None:
            raise ValueError("State is required for PI05")
        tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
        if tasks is None:
            raise ValueError("No task found in complementary data")

        # TODO: check if this necessary
        state = deepcopy(state)

        # Prepare state (pad to max_state_dim)
        state = pad_vector(state, self.max_state_dim)

        # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
        # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
        state_np = state.cpu().numpy()
        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        full_prompts = []
        for i, task in enumerate(tasks):
            cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
            state_str = " ".join(map(str, discretized_states[i]))
            full_prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")

        # `transition.copy()` above is shallow, so the complementary-data mapping is
        # still the caller's object. Write the prompts into a fresh mapping so the
        # input transition is not modified (re-running the step on the same input
        # would otherwise wrap the task text a second time).
        transition[TransitionKey.COMPLEMENTARY_DATA] = {
            **transition[TransitionKey.COMPLEMENTARY_DATA],
            self.task_key: full_prompts,
        }
        return transition

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """
        This step does not alter the feature definitions.
        """
        return features
|
||||
|
||||
|
||||
def make_pi05_video_pre_post_processors(
    config: PI05VideoConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Build the pre-processor and post-processor pipelines for the PI05Video policy.

    The pre-processing pipeline runs these steps, in order:
    1. `RenameObservationsProcessorStep` with an empty rename map.
    2. `AddBatchDimensionProcessorStep`.
    3. `NormalizerProcessorStep` over the input and output features, using
       `config.normalization_mapping` and `dataset_stats`.
    4. `Pi05PrepareStateTokenizerProcessorStep`, which needs the normalized state
       and therefore must come after the normalizer.
    5. `TokenizerProcessorStep` with the `google/paligemma-3b-pt-224` tokenizer,
       right-padded to `config.tokenizer_max_length`.
    6. `DeviceProcessorStep` targeting `config.device`.

    The post-processing pipeline runs:
    1. `UnnormalizerProcessorStep` over the output features.
    2. `DeviceProcessorStep` targeting the CPU.

    Args:
        config: The configuration object for the PI05Video policy.
        dataset_stats: A dictionary of statistics for normalization.

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines.
    """
    # To mimic the same processor as pretrained one
    rename_step = RenameObservationsProcessorStep(rename_map={})
    batch_step = AddBatchDimensionProcessorStep()
    # NOTE: the normalizer MUST come before Pi05PrepareStateTokenizerProcessorStep
    # because that step expects normalized state in [-1, 1] range for discretization
    normalize_step = NormalizerProcessorStep(
        features={**config.input_features, **config.output_features},
        norm_map=config.normalization_mapping,
        stats=dataset_stats,
    )
    prompt_step = Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim)
    tokenize_step = TokenizerProcessorStep(
        tokenizer_name="google/paligemma-3b-pt-224",
        max_length=config.tokenizer_max_length,
        padding_side="right",
        padding="max_length",
    )
    to_device_step = DeviceProcessorStep(device=config.device)
    input_steps: list[ProcessorStep] = [
        rename_step,
        batch_step,
        normalize_step,
        prompt_step,
        tokenize_step,
        to_device_step,
    ]

    unnormalize_step = UnnormalizerProcessorStep(
        features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
    )
    to_cpu_step = DeviceProcessorStep(device="cpu")
    output_steps: list[ProcessorStep] = [unnormalize_step, to_cpu_step]

    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=input_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=output_steps,
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
|
||||
@@ -1,214 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Test script for PI05 with video encoder (VideoPrism).
|
||||
|
||||
This script creates a dummy example to test the model with video encoding enabled.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.videovla.configuration_pi05 import PI05VideoConfig
|
||||
from lerobot.policies.videovla.modeling_pi05 import PI05VideoPolicy
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
def create_dummy_batch(
    batch_size: int = 2,
    num_frames: int = 16,
    image_size: int = 224,
    num_cameras: int = 2,
    state_dim: int = 14,
    action_dim: int = 14,
    chunk_size: int = 50,
    seq_len: int = 10,
    device: str = "cuda",
) -> dict[str, torch.Tensor]:
    """Build a random batch with image, state, language and action entries.

    Every tensor is allocated directly on `device`. Each camera entry has shape
    `[batch_size, num_frames, 3, image_size, image_size]` with values drawn
    from `torch.rand`.
    """
    frames_shape = (batch_size, num_frames, 3, image_size, image_size)
    tokens_shape = (batch_size, seq_len)

    # One random clip per camera, keyed as "<OBS_IMAGES>.camera_<index>".
    dummy = {
        f"{OBS_IMAGES}.camera_{cam}": torch.rand(*frames_shape, device=device)
        for cam in range(num_cameras)
    }

    # State observation: [B, state_dim]
    dummy[OBS_STATE] = torch.rand(batch_size, state_dim, device=device)

    # Language tokens and an all-True attention mask: [B, seq_len]
    dummy["observation.language.tokens"] = torch.randint(0, 1000, tokens_shape, device=device)
    dummy["observation.language.attention_mask"] = torch.ones(
        *tokens_shape, dtype=torch.bool, device=device
    )

    # Action targets: [B, chunk_size, action_dim]
    dummy[ACTION] = torch.rand(batch_size, chunk_size, action_dim, device=device)

    return dummy
|
||||
|
||||
|
||||
def test_video_encoder():
    """Smoke-test a training forward pass and an inference pass with the video encoder on."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Sizes shared by the config, the feature declarations and the dummy batch
    sizes = {
        "batch_size": 2,
        "num_frames": 16,
        "image_size": 224,
        "num_cameras": 2,
        "state_dim": 14,
        "action_dim": 14,
        "chunk_size": 50,
    }

    # Create config with video encoder enabled
    print("Creating PI05VideoConfig with video encoder...")
    config = PI05VideoConfig(
        use_video_encoder=True,
        video_num_frames=sizes["num_frames"],
        videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
        videoprism_image_size=288,
        freeze_video_encoder=True,
        video_padding_mode="repeat",
        video_encoder_camera_key=f"{OBS_IMAGES}.camera_0",  # Use first camera for video
        chunk_size=sizes["chunk_size"],
        max_action_dim=32,
        max_state_dim=32,
        dtype="float32",  # Use float32 for testing
        device=device,
    )

    # Declare one visual feature per camera, then the state and action features
    image_shape = (3, sizes["image_size"], sizes["image_size"])
    for cam in range(sizes["num_cameras"]):
        config.input_features[f"{OBS_IMAGES}.camera_{cam}"] = PolicyFeature(
            type=FeatureType.VISUAL,
            shape=image_shape,
        )
    config.input_features[OBS_STATE] = PolicyFeature(
        type=FeatureType.STATE,
        shape=(sizes["state_dim"],),
    )
    config.output_features[ACTION] = PolicyFeature(
        type=FeatureType.ACTION,
        shape=(sizes["action_dim"],),
    )

    # Echo the video-related settings, one "<name>: <value>" line each
    for attr in (
        "use_video_encoder",
        "video_num_frames",
        "video_padding_mode",
        "video_encoder_camera_key",
        "image_observation_delta_indices",
    ):
        print(f"{attr}: {getattr(config, attr)}")

    # Create model
    model = PI05VideoPolicy(config)
    model.to(device)

    # Create dummy batch (seq_len keeps its default)
    batch = create_dummy_batch(**sizes, device=device)

    print(f"Batch keys: {list(batch)}")
    for name, tensor in batch.items():
        print(f"{name}: {tensor.shape}")

    # Test forward pass; failures are reported and then re-raised
    model.train()
    try:
        loss, loss_dict = model.forward(batch)
        print("Forward pass successful!")
        print(f"Loss: {loss.item():.4f}")
        print(f"Loss dict: {loss_dict}")
    except Exception as err:
        print(f"Forward pass failed: {err}")
        raise

    # Test inference
    model.eval()
    with torch.no_grad():
        try:
            actions = model.predict_action_chunk(batch)
            print("Test pass, inference pass!")
            print(f"Predicted actions shape: {actions.shape}")
        except Exception as err:
            print(f"Inference failed: {err}")
            raise

    print("All tests passed!")
|
||||
|
||||
|
||||
def test_frame_padding():
    """Test frame padding at episode start.

    Checks that `_preprocess_video` returns `video_num_frames` frames both when
    given fewer frames than configured and when given a single frame without a
    time dimension. A `None` result now fails the test; previously it was only
    printed and the test still reported success.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create config
    config = PI05VideoConfig(
        use_video_encoder=True,
        video_num_frames=16,
        videoprism_model_name="MHRDYN7/videoprism-base-f16r288",
        freeze_video_encoder=True,
        video_padding_mode="repeat",
        chunk_size=50,
        dtype="float32",
        device=device,
    )

    # Set up minimal features
    config.input_features[f"{OBS_IMAGES}.camera_0"] = PolicyFeature(
        type=FeatureType.VISUAL,
        shape=(3, 224, 224),
    )
    config.output_features[ACTION] = PolicyFeature(
        type=FeatureType.ACTION,
        shape=(14,),
    )

    # Create model
    model = PI05VideoPolicy(config)
    model.to(device)

    expected_shape = (2, 16, 3, 224, 224)

    # Test with fewer frames than expected (simulating episode start)
    batch = {
        f"{OBS_IMAGES}.camera_0": torch.rand(2, 5, 3, 224, 224, device=device),
        "observation.language.tokens": torch.randint(0, 1000, (2, 10), device=device),
        "observation.language.attention_mask": torch.ones(2, 10, dtype=torch.bool, device=device),
        ACTION: torch.rand(2, 50, 14, device=device),
    }

    video_frames = model._preprocess_video(batch)
    assert video_frames is not None, "video_frames is None (unexpected)"
    print("Input frames: 5")
    print(f"Output video_frames shape: {video_frames.shape}")
    print("Expected: [2, 16, 3, 224, 224]")
    assert video_frames.shape == expected_shape, f"Unexpected shape: {video_frames.shape}"
    print("Frame padding test PASSED!")

    # Test with single frame
    batch[f"{OBS_IMAGES}.camera_0"] = torch.rand(2, 3, 224, 224, device=device)  # [B, C, H, W]

    video_frames = model._preprocess_video(batch)
    assert video_frames is not None, "video_frames is None (unexpected)"
    print("Input: single frame [B, C, H, W]")
    print(f"Output video_frames shape: {video_frames.shape}")
    print("Expected: [2, 16, 3, 224, 224]")
    assert video_frames.shape == expected_shape, f"Unexpected shape: {video_frames.shape}"
    print("Single frame expansion test PASSED!")

    print("All tests passed!")
|
||||
if __name__ == "__main__":
    # Run the padding check first, then the full forward/inference smoke test.
    for run_check in (test_frame_padding, test_video_encoder):
        run_check()
|
||||
@@ -1,37 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
|
||||
from .modeling_videoprism import (
|
||||
VideoPrismClipModel,
|
||||
VideoPrismForVideoClassification,
|
||||
VideoPrismPreTrainedModel,
|
||||
VideoPrismTextModel,
|
||||
VideoPrismVideoModel,
|
||||
VideoPrismVisionModel,
|
||||
)
|
||||
from .video_processing_videoprism import VideoPrismVideoProcessor
|
||||
|
||||
__all__ = [
|
||||
"VideoPrismConfig",
|
||||
"VideoPrismTextConfig",
|
||||
"VideoPrismVisionConfig",
|
||||
"VideoPrismClipModel",
|
||||
"VideoPrismForVideoClassification",
|
||||
"VideoPrismPreTrainedModel",
|
||||
"VideoPrismTextModel",
|
||||
"VideoPrismVideoModel",
|
||||
"VideoPrismVisionModel",
|
||||
"VideoPrismVideoProcessor",
|
||||
]
|
||||
@@ -1,269 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class VideoPrismVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VideoPrismVisionModel`]. It is used to instantiate a
    VideoPrism vision encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the VideoPrism
    [google/videoprism](https://huggingface.co/google/videoprism) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_size (`int`, *optional*, defaults to 288):
            The size of the input image.
        num_frames (`int`, *optional*, defaults to 16):
            The number of frames in the input video.
        tubelet_size (`List[int]`, *optional*, defaults to `[1, 18, 18]`):
            The size of the tubelet patch.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_spatial_layers (`int`, *optional*, defaults to 12):
            Number of spatial transformer blocks.
        num_temporal_layers (`int`, *optional*, defaults to 4):
            Number of temporal transformer blocks.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_python"`):
            The non-linear activation function (function or string).
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the qkv projections in attention layers.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
            Softcapping constant for attention logits.
        num_auxiliary_layers (`int`, *optional*, defaults to 2):
            Number of auxiliary layers. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.
        apply_l2_norm (`bool`, *optional*, defaults to `True`):
            Whether to apply L2 normalization to the output. This is used in the VideoPrismVideoModel that is a part of VideoPrismClipModel.

    Example:

    ```python
    >>> from transformers import VideoPrismVisionConfig, VideoPrismVisionModel

    >>> # Initializing a VideoPrismVisionConfig with default values
    >>> configuration = VideoPrismVisionConfig()

    >>> # Initializing a VideoPrismVisionModel with the configuration
    >>> model = VideoPrismVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "videoprism_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        image_size=288,
        num_frames=16,
        tubelet_size=[1, 18, 18],
        num_channels=3,
        hidden_size=768,
        num_spatial_layers=12,
        num_temporal_layers=4,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu_python",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-06,
        qkv_bias=True,
        attn_logit_softcapping=50.0,
        num_auxiliary_layers=2,
        apply_l2_norm=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        self.image_size = image_size
        self.num_frames = num_frames
        # Store a copy: the signature default is a single list object shared by
        # every call, so keeping the reference would let a mutation on one config
        # leak into all other default-constructed configs.
        self.tubelet_size = list(tubelet_size)
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.num_spatial_layers = num_spatial_layers
        self.num_temporal_layers = num_temporal_layers
        self.attn_logit_softcapping = attn_logit_softcapping
        self.num_auxiliary_layers = num_auxiliary_layers
        self.apply_l2_norm = apply_l2_norm
|
||||
|
||||
|
||||
class VideoPrismTextConfig(PretrainedConfig):
    r"""
    Configuration container for a [`VideoPrismTextModel`].

    The values stored here fully describe the architecture of the VideoPrism text encoder. Building the
    configuration without arguments gives a setup similar to the
    [google/videoprism](https://huggingface.co/google/videoprism) architecture.

    This class derives from [`PretrainedConfig`]; refer to the [`PretrainedConfig`] documentation for the
    options shared by every configuration object.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and of the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward ("intermediate") layer inside each Transformer encoder block.
        num_attention_heads (`int`, *optional*, defaults to 12):
            How many attention heads each attention layer of the encoder uses.
        num_text_layers (`int`, *optional*, defaults to 12):
            How many hidden layers the text Transformer encoder has.
        vocab_size (`int`, *optional*, defaults to 32000):
            Size of the text vocabulary, i.e. the number of distinct token ids that may appear in the
            `input_ids` given to [`VideoPrismTextModel`].
        apply_l2_norm (`bool`, *optional*, defaults to `True`):
            If set, the output text embeddings are L2-normalized.
        hidden_act (`str` or `function`, *optional*, defaults to `"relu"`):
            Non-linearity (given as a function or by name) used in the encoder and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
            Dropout ratio applied to the attention probabilities.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            If set, the query, key and value projections of the attention layers carry a bias term.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
            Dropout probability of every fully connected layer in the embeddings, encoder and pooler.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon of the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated normal initializer used for all weight matrices.
        attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
            Soft-capping constant applied to the attention logits.

    Example:

    ```python
    >>> from transformers import VideoPrismTextConfig, VideoPrismTextModel

    >>> # Build a text configuration that uses the default values
    >>> configuration = VideoPrismTextConfig()

    >>> # Build a randomly initialized text model from that configuration
    >>> model = VideoPrismTextModel(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config
    ```"""

    model_type = "videoprism_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_attention_heads=12,
        num_text_layers=12,
        vocab_size=32000,
        apply_l2_norm=True,
        hidden_act="relu",
        attention_probs_dropout_prob=0.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        layer_norm_eps=1e-06,
        initializer_range=0.02,
        attn_logit_softcapping=50.0,
        **kwargs,
    ):
        # Generic options are consumed by the base class first.
        super().__init__(**kwargs)

        # Encoder geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_text_layers = num_text_layers
        self.vocab_size = vocab_size

        # Output post-processing and activation.
        self.apply_l2_norm = apply_l2_norm
        self.hidden_act = hidden_act

        # Attention / regularisation / numerical settings.
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.attn_logit_softcapping = attn_logit_softcapping
|
||||
|
||||
|
||||
class VideoPrismConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`VideoPrismClipModel`]. It is used to
    instantiate a VideoPrism model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the VideoPrism
    [google/videoprism](https://huggingface.co/google/videoprism) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`VideoPrismTextConfig` or `dict`, *optional*):
            Configuration for the text model. A `dict` is expanded into a [`VideoPrismTextConfig`]; `None`
            falls back to a default [`VideoPrismTextConfig`].
        vision_config (`VideoPrismVisionConfig` or `dict`, *optional*):
            Configuration for the vision model. A `dict` is expanded into a [`VideoPrismVisionConfig`]; `None`
            falls back to a default [`VideoPrismVisionConfig`].
        kwargs (*optional*):
            Dictionary of keyword arguments forwarded to [`PretrainedConfig`].

    Example:

    ```python
    >>> from transformers import VideoPrismConfig, VideoPrismClipModel

    >>> # Initializing a VideoPrismConfig with default values
    >>> configuration = VideoPrismConfig()

    >>> # Initializing a VideoPrismClipModel with the configuration
    >>> model = VideoPrismClipModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "videoprism"
    sub_configs = {"text_config": VideoPrismTextConfig, "vision_config": VideoPrismVisionConfig}

    def __init__(self, text_config=None, vision_config=None, **kwargs):
        # Normalise the text sub-config: None -> defaults, dict -> config object,
        # an already-built config object is kept untouched.
        if text_config is None:
            text_config = VideoPrismTextConfig()
            logger.info("`text_config` is `None`. Initializing the `VideoPrismTextConfig` with default values.")
        elif isinstance(text_config, dict):
            text_config = VideoPrismTextConfig(**text_config)

        # Same normalisation for the vision sub-config.
        if vision_config is None:
            vision_config = VideoPrismVisionConfig()
            logger.info("`vision_config` is `None`. initializing the `VideoPrismVisionConfig` with default values.")
        elif isinstance(vision_config, dict):
            vision_config = VideoPrismVisionConfig(**vision_config)

        # The sub-configs are attached before the base initializer runs, as in the original ordering.
        self.text_config = text_config
        self.vision_config = vision_config

        super().__init__(**kwargs)
|
||||
|
||||
|
||||
# Public names exported by this configuration module.
__all__ = [
    "VideoPrismVisionConfig",
    "VideoPrismTextConfig",
    "VideoPrismConfig",
]
|
||||
@@ -1,245 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Snapshot of the genuine `torch.nn.init` primitives, taken at import time. The context managers in this file
# swap the attributes on the torch modules, so the untouched originals have to be reachable through this table.
TORCH_INIT_FUNCTIONS = {
    init_name: getattr(torch.nn.init, init_name)
    for init_name in (
        "uniform_",
        "normal_",
        "constant_",
        "ones_",
        "zeros_",
        "eye_",
        "dirac_",
        "xavier_uniform_",
        "xavier_normal_",
        "kaiming_uniform_",
        "kaiming_normal_",
        "trunc_normal_",
        "orthogonal_",
        "sparse_",
    )
}
|
||||
|
||||
|
||||
def uniform_(
|
||||
tensor: torch.Tensor, a: float = 0.0, b: float = 1.0, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["uniform_"](tensor, a=a, b=b, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def normal_(
|
||||
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["normal_"](tensor, mean=mean, std=std, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def constant_(tensor: torch.Tensor, val: float) -> torch.Tensor:
    """Guarded `constant_`: returns `tensor` untouched when it carries a truthy `_is_hf_initialized` flag."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return TORCH_INIT_FUNCTIONS["constant_"](tensor, val=val)
|
||||
|
||||
|
||||
def ones_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded `ones_`: returns `tensor` untouched when it carries a truthy `_is_hf_initialized` flag."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return TORCH_INIT_FUNCTIONS["ones_"](tensor)
|
||||
|
||||
|
||||
def zeros_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded `zeros_`: returns `tensor` untouched when it carries a truthy `_is_hf_initialized` flag."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return TORCH_INIT_FUNCTIONS["zeros_"](tensor)
|
||||
|
||||
|
||||
def eye_(tensor: torch.Tensor) -> torch.Tensor:
    """Guarded `eye_`: returns `tensor` untouched when it carries a truthy `_is_hf_initialized` flag."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return TORCH_INIT_FUNCTIONS["eye_"](tensor)
|
||||
|
||||
|
||||
def dirac_(tensor: torch.Tensor, groups: int = 1) -> torch.Tensor:
    """Guarded `dirac_`: returns `tensor` untouched when it carries a truthy `_is_hf_initialized` flag."""
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    return TORCH_INIT_FUNCTIONS["dirac_"](tensor, groups=groups)
|
||||
|
||||
|
||||
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["xavier_uniform_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["xavier_normal_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def kaiming_uniform_(
|
||||
tensor: torch.Tensor,
|
||||
a: float = 0,
|
||||
mode: str = "fan_in",
|
||||
nonlinearity: str = "leaky_relu",
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["kaiming_uniform_"](
|
||||
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
def kaiming_normal_(
|
||||
tensor: torch.Tensor,
|
||||
a: float = 0,
|
||||
mode: str = "fan_in",
|
||||
nonlinearity: str = "leaky_relu",
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["kaiming_normal_"](
|
||||
tensor, a=a, mode=mode, nonlinearity=nonlinearity, generator=generator
|
||||
)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(
|
||||
tensor: torch.Tensor,
|
||||
mean: float = 0.0,
|
||||
std: float = 1.0,
|
||||
a: float = -2.0,
|
||||
b: float = 2.0,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["trunc_normal_"](tensor, mean=mean, std=std, a=a, b=b, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def orthogonal_(
|
||||
tensor: torch.Tensor,
|
||||
gain: float = 1,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["orthogonal_"](tensor, gain=gain, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def sparse_(
|
||||
tensor: torch.Tensor, sparsity: float, std: float = 0.01, generator: torch.Generator | None = None
|
||||
) -> torch.Tensor:
|
||||
if not getattr(tensor, "_is_hf_initialized", False):
|
||||
return TORCH_INIT_FUNCTIONS["sparse_"](tensor, sparsity=sparsity, std=std, generator=generator)
|
||||
return tensor
|
||||
|
||||
|
||||
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """
    Copy `other` into `tensor` in place, outside of autograd tracking.

    The copy is skipped (and `tensor` returned as-is) when `tensor` carries a truthy `_is_hf_initialized` flag.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    with torch.no_grad():
        return tensor.copy_(other)
|
||||
|
||||
|
||||
# Several already-imported torch modules have to be hot-patched, not only `torch.nn.init`: torch internals
# sometimes do `from torch.nn.init import xavier_uniform_` (e.g. in torch.nn.modules.activation, where
# MultiheadAttention lives), which binds the function name at import time. Replacing the attribute on
# `torch.nn.init` alone would therefore leave those modules pointing at the unpatched function.
# The list below is expected to be sufficient for the torch versions this file is used with.
TORCH_MODULES_TO_PATCH = (
    "torch.nn.init",
    *(
        f"torch.nn.modules.{submodule}"
        for submodule in (
            "activation",
            "transformer",
            "linear",
            "loss",
            "batchnorm",
            "conv",
            "normalization",
            "rnn",
            "sparse",
        )
    ),
)
|
||||
|
||||
|
||||
@contextmanager
def guard_torch_init_functions():
    """
    Temporarily make the `torch.nn.init` primitives behave like the guarded wrappers defined in this file.

    While the context is active, every init function listed in `TORCH_INIT_FUNCTIONS` that is found on one of
    the already-imported modules in `TORCH_MODULES_TO_PATCH` is replaced by the same-named function from this
    module, so that tensors flagged with `_is_hf_initialized` are not re-initialized. The previous attributes
    are put back on exit, including when the body raises.

    Models normally use the guarded init helpers directly; this context manager is an extra safety net,
    notably for remote code.
    """
    saved = defaultdict(dict)
    guarded_namespace = globals()
    try:
        for module_name in TORCH_MODULES_TO_PATCH:
            # Only modules that have already been imported are patched.
            if module_name not in sys.modules:
                continue
            module = sys.modules[module_name]
            for func_name in TORCH_INIT_FUNCTIONS:
                if not hasattr(module, func_name):
                    continue
                # Remember whatever is currently bound so it can be restored afterwards.
                saved[module][func_name] = getattr(module, func_name)
                setattr(module, func_name, guarded_namespace[func_name])
        yield
    finally:
        # Undo every replacement that was actually performed.
        for module, replaced in saved.items():
            for func_name, previous in replaced.items():
                setattr(module, func_name, previous)
|
||||
|
||||
|
||||
@contextmanager
def no_init_weights():
    """
    Disable weight initialization both at the torch-level, and at the transformers-level (`init_weights`).

    While the context is active, every init function listed in `TORCH_INIT_FUNCTIONS` that is found on one of
    the already-imported modules in `TORCH_MODULES_TO_PATCH` is replaced by a no-op, and
    `PreTrainedModel.init_weights` is replaced by the same no-op. Everything is restored on exit, including
    when the body raises.

    This is used to speed-up initializing an empty model with deepspeed, as we do not initialize the model on
    meta device with deepspeed, but we still don't need to run expensive weight initializations as we are
    loading params afterwards.
    """
    from .modeling_utils import PreTrainedModel

    def empty_func(*args, **kwargs):
        pass

    originals = defaultdict(dict)
    # Captured before entering the `try` so the `finally` clause can always restore it; binding it inside the
    # `try` would turn any failure during patching into an UnboundLocalError that hides the real exception.
    original_init_weights = PreTrainedModel.init_weights
    try:
        # Replace all torch funcs by empty ones
        for module_name in TORCH_MODULES_TO_PATCH:
            if module_name in sys.modules:
                module = sys.modules[module_name]
                for func_name in TORCH_INIT_FUNCTIONS.keys():
                    if hasattr(module, func_name):
                        originals[module][func_name] = getattr(module, func_name)
                        setattr(module, func_name, empty_func)

        # Also patch our own `init_weights`
        PreTrainedModel.init_weights = empty_func

        yield
    finally:
        # Set back the original torch functions on all modules
        for module, functions in originals.items():
            for func_name, func in functions.items():
                setattr(module, func_name, func)
        # Set back `init_weights`
        PreTrainedModel.init_weights = original_init_weights
|
||||
@@ -1,994 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from . import initialization as init
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.masking_utils import create_causal_mask
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
from transformers.modeling_outputs import BaseModelOutput, ImageClassifierOutput
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from transformers.file_utils import ModelOutput
|
||||
|
||||
from .configuration_videoprism import VideoPrismConfig, VideoPrismTextConfig, VideoPrismVisionConfig
|
||||
|
||||
def torch_int(x):
    """
    Casts an input to a torch int64 tensor if we are in a tracing context, otherwise to a Python int.

    Args:
        x: A number or a `torch.Tensor` holding a single value.

    Returns:
        `x.to(torch.int64)` when `torch.jit` is tracing and `x` is a tensor, otherwise `int(x)`.
    """
    # torch is imported unconditionally by this module, so no availability probe is needed here. The previous
    # `torch.is_available()` probe is not an attribute of the top-level `torch` module and raised
    # AttributeError on every call.
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)
|
||||
|
||||
@dataclass
class BaseModelOutputWithSpatialAndTemporalStates(ModelOutput):
    """
    Model output carrying the final hidden state together with the spatial and temporal encoder states.

    Args:
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Final hidden state of the model; usually shaped
            `(batch_size, num_patches * num_frames, hidden_size)`.
        temporal_hidden_state (`torch.FloatTensor`, *optional*):
            Last hidden state produced by the temporal encoder; usually shaped
            `(batch_size * num_patches, num_frames, hidden_size)`.
        spatial_hidden_state (`torch.FloatTensor`, *optional*):
            Last hidden state produced by the spatial encoder; usually shaped
            `(batch_size * num_frames, num_patches, hidden_size)`.
    """

    last_hidden_state: torch.FloatTensor | None = None
    temporal_hidden_state: torch.FloatTensor | None = None
    spatial_hidden_state: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
class VideoPrismClipOutput(ModelOutput):
    """
    Output container for the VideoPrismClip model.

    All four fields are optional tensors defaulting to `None`: `logits_per_video`, `logits_per_text`,
    `video_embeds` and `text_embeds`.

    NOTE(review): shapes and exact semantics of the fields are set by the model's forward pass, which is not
    part of this definition -- confirm there before relying on them.
    """

    logits_per_video: torch.FloatTensor | None = None
    logits_per_text: torch.FloatTensor | None = None
    video_embeds: torch.FloatTensor | None = None
    text_embeds: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
class VideoPrismVideoOutput(ModelOutput):
    """
    Output container for the VideoPrismVideo model.

    All three fields are optional tensors defaulting to `None`: `video_last_hidden_state`,
    `auxiliary_output` and `attention_pooling_output`.

    NOTE(review): shapes and exact semantics of the fields are set by the model's forward pass, which is not
    part of this definition -- confirm there before relying on them.
    """

    video_last_hidden_state: torch.FloatTensor | None = None
    auxiliary_output: torch.FloatTensor | None = None
    attention_pooling_output: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class VideoPrismTubeletEmbeddings(nn.Module):
    """
    Construct VideoPrism Tubelet embeddings.

    This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a
    tensor of shape (batch_size * out_frames, num_patches, hidden_size) to be consumed by a Transformer encoder,
    where `out_frames` is the temporal size produced by the strided 3D convolution
    (num_frames // tubelet_size[0]) and `num_patches` is
    (height // tubelet_size[1]) * (width // tubelet_size[2]).
    """

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.num_frames = config.num_frames

        # `image_size` may be a single int (square input) or a (height, width) pair. A pair can arrive as a
        # list (e.g. after a JSON round-trip of the config), so lists are accepted alongside tuples.
        image_size = config.image_size
        if isinstance(image_size, (tuple, list)):
            self.image_size = tuple(image_size)
        else:
            self.image_size = (image_size, image_size)

        self.patch_size = config.tubelet_size
        self.embed_dim = config.hidden_size

        # Non-overlapping tubelets: kernel and stride are both the tubelet size.
        self.projection = nn.Conv3d(
            config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
        )
        self.pos_emb_shape = [self.image_size[0] // self.patch_size[1], self.image_size[1] // self.patch_size[2]]
        self.num_patches = self.pos_emb_shape[0] * self.pos_emb_shape[1]

    def forward(self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Project a video batch into per-frame patch embeddings.

        Args:
            pixel_values_videos: Tensor of shape (batch_size, num_frames, num_channels, height, width).
            interpolate_pos_encoding: When `False`, the spatial size must equal the configured image size.

        Returns:
            Tensor of shape (batch_size * out_frames, num_patches, hidden_size).

        Raises:
            ValueError: If the spatial size differs from the configured one and `interpolate_pos_encoding`
                is `False`.
        """
        # The 5-way unpacking also enforces that the input is a 5-D tensor.
        _, _, _, height, width = pixel_values_videos.shape
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]}). Set interpolate_pos_encoding=True to automatically resize the model position embeddings."
            )
        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values_videos = pixel_values_videos.permute(0, 2, 1, 3, 4)

        hidden_states = self.projection(pixel_values_videos)
        # flatten the spatial part and permute to (B, T, num_patches, dim)
        hidden_states = hidden_states.flatten(3).permute(0, 2, 3, 1)
        # combine batch and time dimension
        batch_size, num_frames, num_patches, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(batch_size * num_frames, num_patches, hidden_size)

        return hidden_states
|
||||
|
||||
|
||||
class VideoPrismSpatialEmbeddings(nn.Module):
    """
    VideoPrism Spatial Embeddings.

    Embeds a video with `VideoPrismTubeletEmbeddings`, adds learned spatial position embeddings and applies
    dropout.
    """

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.patch_embeddings = VideoPrismTubeletEmbeddings(config)
        # One learned position vector per spatial patch, shared by all frames.
        position_shape = (1, self.patch_embeddings.num_patches, config.hidden_size)
        self.position_embeddings = nn.Parameter(torch.zeros(*position_shape))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.tubelet_size = config.tubelet_size
        # Spatial part (height, width) of the tubelet.
        self.patch_size = config.tubelet_size[1:]

    # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Resize the learned spatial position embeddings to the patch grid implied by `height` x `width`, so the
        model can be run at a resolution other than the one it was configured with. Interpolation is always
        performed under `torch.jit` tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        learned = self.position_embeddings
        num_positions = learned.shape[1]

        # Nothing to resize when the grid already matches (never short-circuited while tracing, so that an
        # exported model keeps working with dynamic input shapes).
        grid_matches = embeddings.shape[1] == num_positions and height == width
        if grid_matches and not torch.jit.is_tracing():
            return learned

        dim = embeddings.shape[-1]
        target_grid = (height // self.patch_size[0], width // self.patch_size[1])

        # The stored embeddings are laid out as a square grid of side sqrt(num_positions).
        side = torch_int(num_positions**0.5)
        grid = learned.reshape(1, side, side, dim).permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            size=target_grid,
            mode="bilinear",
            antialias=True,
        )

        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(
        self, pixel_values_videos: torch.Tensor, interpolate_pos_encoding: bool | None = False
    ) -> torch.Tensor:
        """Embed `pixel_values_videos` (batch, frames, channels, height, width) and add position information."""
        _, _, _, h, w = pixel_values_videos.shape
        assert h == w, "Input image height and width must be the same"
        embeddings = self.patch_embeddings(pixel_values_videos, interpolate_pos_encoding)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(embeddings, h, w)
        else:
            positions = self.position_embeddings

        return self.dropout(embeddings + positions)
|
||||
|
||||
|
||||
class VideoPrismTemporalEmbeddings(nn.Module):
    """
    VideoPrism Temporal Embeddings.

    Receives embeddings from spatial encoder, reshapes the hidden state to
    (batch_size * num_patches, num_frames, hidden_size) and adds positional embeddings.
    """

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config

        # One learned position vector per frame.
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.config.num_frames, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Resize the learned temporal position embeddings along the frame axis so that their length matches
        `embeddings.shape[1]`, which allows running the model on a number of frames other than the configured
        one. Interpolation is always performed under `torch.jit` tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        target_emb_length = embeddings.shape[1]
        source_emb_length = self.position_embeddings.shape[1]

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and target_emb_length == source_emb_length:
            return self.position_embeddings

        source_emb = self.position_embeddings
        dim = embeddings.shape[-1]
        # (1, frames, dim) -> (1, 1, frames, dim): bilinear interpolation expects a 4-D input.
        source_emb = source_emb.unsqueeze(1)
        source_emb = nn.functional.interpolate(
            source_emb,
            size=(target_emb_length, dim),
            mode="bilinear",
            antialias=True,
        )

        return source_emb.squeeze(1)

    def forward(
        self,
        pixel_values_videos: torch.Tensor,
        input_shape: torch.Size,
        interpolate_pos_encoding: bool | None = False,
    ) -> torch.Tensor:
        """
        Regroup spatial-encoder states per patch and add temporal position embeddings.

        Args:
            pixel_values_videos: Hidden states of shape (batch_size * num_frames, num_patches, hidden_size)
                when `input_shape` is given.
            input_shape: The 5-element shape (batch, frames, channels, height, width) of the original video
                batch. When `None`, `pixel_values_videos` is used as-is and is expected to already be laid
                out as (batch_size * num_patches, num_frames, hidden_size).
            interpolate_pos_encoding: Whether to interpolate the position embeddings to the number of frames.

        Returns:
            Tensor of shape (batch_size * num_patches, num_frames, hidden_size).
        """
        if input_shape is None:
            # Without a reference shape no regrouping is possible; previously this path left `embeddings`
            # unbound and failed with UnboundLocalError.
            embeddings = pixel_values_videos
        else:
            b, t, _, _, _ = input_shape
            _, features, dim = pixel_values_videos.shape
            hidden_states = pixel_values_videos.view(b, t, features, dim)
            hidden_states = hidden_states.permute(0, 2, 1, 3)
            embeddings = hidden_states.reshape(b * features, t, dim)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings)
        else:
            embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
softcap: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
|
||||
if softcap is not None:
|
||||
attn_weights = attn_weights / softcap
|
||||
attn_weights = torch.tanh(attn_weights)
|
||||
attn_weights = attn_weights * softcap
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask.expand(*attn_weights.shape)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class VideoPrismSelfAttention(nn.Module):
    """
    Multi-head self-attention with separate query/key/value projections and attention-logit soft-capping.

    The attention kernel is chosen from `config._attn_implementation`: the eager implementation of this file,
    or an entry of `ALL_ATTENTION_FUNCTIONS` otherwise.
    """

    def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
        super().__init__()
        heads_divide_hidden = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_hidden and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dropout_prob = config.attention_probs_dropout_prob
        # 1 / sqrt(head_dim) scaling of the attention scores.
        self.scale = self.attention_head_size**-0.5

        use_bias = config.qkv_bias
        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=use_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run self-attention over `hidden_states`.

        Returns:
            Tuple `(context_layer, attention_probs)` where `context_layer` has its head dimensions merged
            back into a single dimension of size `all_head_size`.
        """
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim) for each projection.
        per_head_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
        query = self.query(hidden_states).view(*per_head_shape).transpose(1, 2)
        key = self.key(hidden_states).view(*per_head_shape).transpose(1, 2)
        value = self.value(hidden_states).view(*per_head_shape).transpose(1, 2)

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        # Attention dropout is only active in training mode.
        dropout = self.dropout_prob if self.training else 0.0
        context_layer, attention_probs = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask,
            scaling=self.scale,
            dropout=dropout,
            softcap=self.config.attn_logit_softcapping,
            **kwargs,
        )

        # Merge (heads, head_dim) back into one feature dimension.
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return (context_layer.reshape(merged_shape), attention_probs)
|
||||
|
||||
|
||||
class VideoPrismSelfOutput(nn.Module):
    """
    Output projection applied after self-attention: a linear layer followed by dropout.

    The residual connection is defined in VideoPrismLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # `input_tensor` is accepted for signature compatibility but is not used here.
        return self.dropout(self.dense(hidden_states))
|
||||
|
||||
|
||||
class VideoPrismAttention(nn.Module):
    """Self-attention followed by its output projection; attention probabilities are discarded."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.attention = VideoPrismSelfAttention(config)
        self.output = VideoPrismSelfOutput(config)

    def forward(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        attended = self.attention(hidden_states, attention_mask, **kwargs)[0]
        return self.output(attended, hidden_states)
|
||||
|
||||
|
||||
class VideoPrismLayerNorm(nn.LayerNorm):
    """LayerNorm variant whose effective scale is ``weight + 1``."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The stored weight is an offset from 1, so shift it before normalising.
        shifted_weight = self.weight + 1
        return F.layer_norm(
            hidden_states,
            self.normalized_shape,
            weight=shifted_weight,
            bias=self.bias,
            eps=self.eps,
        )
|
||||
|
||||
|
||||
class VideoPrismIntermediate(nn.Module):
    """First half of the feed-forward block: linear expansion, activation, dropout."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # `hidden_act` may be the name of a registered activation or a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        expanded = self.dense(hidden_states)
        activated = self.intermediate_act_fn(expanded)
        return self.dropout(activated)
|
||||
|
||||
|
||||
class VideoPrismOutput(nn.Module):
    """Second half of the feed-forward block: projection back to hidden size, dropout, residual add."""

    def __init__(self, config: VideoPrismConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the tensor supplied by the caller.
        return projected + input_tensor
|
||||
|
||||
|
||||
class VideoPrismLayer(GradientCheckpointingLayer):
    """This corresponds to the EncoderBlock class in the scenic/videoprism implementation.

    Pre-norm transformer block: layernorm -> attention -> residual, then
    layernorm -> feed-forward -> residual.
    """

    def __init__(self, config: VideoPrismVisionConfig | VideoPrismTextConfig):
        super().__init__()
        self.config = config
        self.attention = VideoPrismAttention(config)
        self.intermediate = VideoPrismIntermediate(config)
        self.output = VideoPrismOutput(config)
        norm_width, norm_eps = config.hidden_size, config.layer_norm_eps
        self.layernorm_before = VideoPrismLayerNorm(norm_width, eps=norm_eps)
        self.layernorm_after = VideoPrismLayerNorm(norm_width, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states
        attended = self.attention(self.layernorm_before(residual), attention_mask, **kwargs)

        # First residual connection, around the attention sub-block.
        hidden_states = attended + residual

        # A second layernorm precedes the feed-forward sub-block.
        feed_forward = self.intermediate(self.layernorm_after(hidden_states))

        # `self.output` performs the second residual addition internally.
        return self.output(feed_forward, hidden_states)
|
||||
|
||||
|
||||
class VideoPrismSpatialEncoder(nn.Module):
    """Stack of ``config.num_spatial_layers`` VideoPrismLayer blocks applied in sequence."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_spatial_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        """Run every layer in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Optional mask forwarded to each layer.
            **kwargs: Extra keyword arguments forwarded to each layer. Accepting them
                keeps this encoder callable the same way as the auxiliary and text
                encoders, and avoids a ``TypeError`` when a caller passes kwargs through.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is the output of the final layer.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismTemporalEncoder(nn.Module):
    """Stack of ``config.num_temporal_layers`` VideoPrismLayer blocks applied in sequence."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoPrismLayer(config) for _ in range(config.num_temporal_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        """Run every layer in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Optional mask forwarded to each layer.
            **kwargs: Extra keyword arguments forwarded to each layer. Accepting them
                keeps this encoder callable the same way as the auxiliary and text
                encoders, and avoids a ``TypeError`` when a caller passes kwargs through.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` is the output of the final layer.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask, **kwargs)

        return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismAuxiliaryEncoder(nn.Module):
    """Stack of ``config.num_auxiliary_layers`` VideoPrismLayer blocks applied in sequence."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        blocks = [VideoPrismLayer(self.config) for _ in range(config.num_auxiliary_layers)]
        self.layer = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        # Each block consumes the previous block's output; mask and kwargs are shared.
        for block in self.layer:
            hidden_states = block(hidden_states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
class VideoPrismTextEncoder(nn.Module):
    """Stack of ``config.num_text_layers`` VideoPrismLayer blocks applied in sequence."""

    def __init__(self, config: VideoPrismTextConfig):
        super().__init__()
        self.config = config
        blocks = [VideoPrismLayer(config) for _ in range(config.num_text_layers)]
        self.layer = nn.ModuleList(blocks)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        # Each block consumes the previous block's output; mask and kwargs are shared.
        for block in self.layer:
            hidden_states = block(hidden_states, attention_mask, **kwargs)
        return BaseModelOutput(last_hidden_state=hidden_states)
|
||||
|
||||
|
||||
def variance_scaling_(tensor, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``1 / denom``.

    Args:
        tensor: Tensor to fill; its fan-in and fan-out are derived from its shape.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects the denominator.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not one of the supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and failed later on an unbound `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = 1.0 / denom

    if distribution == "truncated_normal":
        # The constant rescales std; presumably it compensates for the variance lost
        # to truncation -- TODO confirm against the reference implementation.
        init.trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        init.normal_(tensor, std=math.sqrt(variance))
    elif distribution == "uniform":
        # A uniform distribution on [-b, b] has variance b^2 / 3.
        bound = math.sqrt(3 * variance)
        init.uniform_(tensor, -bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
|
||||
|
||||
|
||||
def lecun_normal_(tensor):
    """Fill ``tensor`` in place using fan-in variance scaling with a truncated normal."""
    variance_scaling_(tensor, "fan_in", "truncated_normal")
|
||||
|
||||
|
||||
class VideoPrismPreTrainedModel(PreTrainedModel):
    """Shared base class for VideoPrism models: declares capabilities and weight initialisation."""

    config_class = VideoPrismConfig
    config: VideoPrismConfig
    base_model_prefix = "videoprism"
    main_input_name = "pixel_values_videos"
    input_modalities = ("video", "text")
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "VideoPrismSpatialEmbeddings",
        "VideoPrismTemporalEmbeddings",
        "VideoPrismSpatialEncoder",
        "VideoPrismTemporalEncoder",
        "VideoPrismAuxiliaryEncoder",
        "VideoPrismTextEncoder",
        "VideoPrismMultiheadAttentionPoolingHead",
    ]
    _supports_sdpa = True
    _supports_flash_attn = True
    _supports_attention_backend = True
    _supports_flex_attention = True

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and Conv3d weights use ``lecun_normal_`` and their biases are zeroed;
        LayerNorm gets zero bias and unit weight. Parameters that are absent (``None``)
        are skipped -- the projections in this file are built with ``bias=config.qkv_bias``,
        so a bias is not guaranteed to exist.
        """
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                init.zeros_(module.bias)

        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                init.zeros_(module.bias)
            if module.weight is not None:
                init.ones_(module.weight)
|
||||
|
||||
|
||||
class VideoPrismVisionModel(VideoPrismPreTrainedModel):
    """Video backbone: spatial embeddings/encoder followed by temporal embeddings/encoder."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        norm_width, norm_eps = config.hidden_size, config.layer_norm_eps
        self.layernorm1 = VideoPrismLayerNorm(norm_width, eps=norm_eps)
        self.layernorm2 = VideoPrismLayerNorm(norm_width, eps=norm_eps)
        self.spatial_embeddings = VideoPrismSpatialEmbeddings(config)
        self.temporal_embeddings = VideoPrismTemporalEmbeddings(config)
        self.spatial_encoder = VideoPrismSpatialEncoder(config)
        self.temporal_encoder = VideoPrismTemporalEncoder(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> BaseModelOutputWithSpatialAndTemporalStates:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames of shape (batch_size, num_frames, num_channels, height, width).
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings to match input size.

        Example:

        ```python
        >>> from transformers import VideoPrismVideoProcessor, VideoPrismVisionModel
        >>> import torch

        >>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismVisionModel.from_pretrained("google/videoprism")

        >>> video = "sample_video.mp4"
        >>> inputs = processor(videos=video)
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     features = outputs.last_hidden_state
        ```
        """
        if pixel_values_videos is None:
            raise ValueError("You have to specify pixel_values_videos")

        input_shape = pixel_values_videos.shape
        batch_size = input_shape[0]

        # Spatial stage; output shape is (B * num_frames, num_patches, dim).
        spatial_embeds = self.spatial_embeddings(pixel_values_videos, interpolate_pos_encoding)
        spatial_sequence_output = self.spatial_encoder(hidden_states=spatial_embeds, **kwargs).last_hidden_state
        spatial_features = self.layernorm1(spatial_sequence_output)

        # Temporal stage; output shape is (B * num_patches, num_frames, dim).
        temporal_embeds = self.temporal_embeddings(spatial_features, input_shape, interpolate_pos_encoding)
        temporal_sequence_output = self.temporal_encoder(hidden_states=temporal_embeds, **kwargs).last_hidden_state
        temporal_features = self.layernorm2(temporal_sequence_output)

        # Regroup per video as (B, num_frames, num_patches, dim), then flatten frames and patches.
        _, frame_count, width = temporal_features.shape
        per_video = temporal_features.view(batch_size, -1, frame_count, width).permute(0, 2, 1, 3).contiguous()
        _, frame_count, patch_count, _ = per_video.shape
        features = per_video.view(batch_size, frame_count * patch_count, -1)

        return BaseModelOutputWithSpatialAndTemporalStates(
            last_hidden_state=features,
            temporal_hidden_state=temporal_sequence_output,
            spatial_hidden_state=spatial_sequence_output,
        )
|
||||
|
||||
|
||||
class VideoPrismMultiheadAttentionPoolingHead(nn.Module):
    """Pools a sequence into a single token with one query attending over all positions."""

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__()
        self.config = config
        head_count = config.num_attention_heads
        head_width = int(config.intermediate_size / head_count)
        self.num_attention_heads = head_count
        self.attention_head_size = head_width
        self.all_head_size = head_count * head_width
        self.dropout_prob = config.attention_probs_dropout_prob

        # PerDimScale
        self.dim = head_width
        self.per_dim_scale = nn.Parameter(torch.zeros(self.dim))
        r_softplus_0 = 1.442695041  # numerically close to 1 / ln(2)
        base_scale = torch.tensor(r_softplus_0 / (self.dim**0.5))
        # NOTE(review): the buffer is computed once here from the zero-initialised
        # `per_dim_scale`; `forward` reads `self.scale` and never recomputes it.
        self.register_buffer("scale", base_scale * nn.functional.softplus(self.per_dim_scale))

        hidden, inner, use_bias = config.hidden_size, config.intermediate_size, config.qkv_bias
        self.pooling_attention_query = nn.Parameter(torch.zeros(1, 1, hidden))
        self.query = nn.Linear(hidden, inner, bias=use_bias)
        self.key = nn.Linear(hidden, inner, bias=use_bias)
        self.value = nn.Linear(hidden, inner, bias=use_bias)
        self.projection = nn.Linear(inner, hidden, bias=use_bias)
        self.layernorm = VideoPrismLayerNorm(hidden, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(pooled_output, attention_probs)`` for a rank-3 ``hidden_states``."""
        # Unpacking three values keeps the requirement that the input is rank 3.
        batch_size, _, _ = hidden_states.shape
        head_layout = (batch_size, -1, self.num_attention_heads, self.attention_head_size)

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            return projected.view(*head_layout).transpose(1, 2)

        pooling_query = self.pooling_attention_query.expand(batch_size, -1, -1)
        query_layer = split_heads(self.query(pooling_query))
        # The per-dimension scale is applied to the query, so the kernel gets scaling=1.0.
        query_layer = query_layer * self.scale.expand(*query_layer.shape)

        key_layer = split_heads(self.key(hidden_states))
        value_layer = split_heads(self.value(hidden_states))

        if self.config._attn_implementation == "eager":
            attention_interface = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            scaling=1.0,
            dropout=self.dropout_prob if self.training else 0.0,
            softcap=None,
            **kwargs,
        )

        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        pooled = self.projection(context_layer.reshape(merged_shape))
        return (self.layernorm(pooled), attention_probs)
|
||||
|
||||
|
||||
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Scale ``x`` to unit L2 norm along ``dim``; ``eps`` is added to the squared sum for stability.

    This function is intended to align with the l2norm implementation in the FLA library.
    """
    squared_sum = torch.sum(x * x, dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
|
||||
|
||||
|
||||
class VideoPrismTextModel(VideoPrismPreTrainedModel):
    """Text tower: token embeddings plus sinusoidal positions, a trailing class token, and a causal encoder."""

    config_class = VideoPrismTextConfig
    config: VideoPrismTextConfig

    def __init__(self, config: VideoPrismTextConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VideoPrismTextEncoder(self.config)
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.cls_emb = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.layernorm = VideoPrismLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.normalize = config.apply_l2_norm
        self.post_init()

    def create_sinusoidal_positions(self, num_pos: int, dim: int) -> torch.Tensor:
        """Return a ``(num_pos, dim)`` table whose first half is sines and second half cosines."""
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / (dim - 2)))
        sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
        return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> BaseModelOutput:
        r"""
        Args:
            input_ids (`torch.Tensor`):
                Input token IDs.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask to avoid performing attention on padding token indices.
        """
        batch_size, seq_length = input_ids.shape
        hidden_states = self.token_embeddings(input_ids)
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        if attention_mask is not None:
            # Extend the mask by one always-visible position for the class token appended below.
            # It is created on the mask's own device and dtype so the concatenation cannot fail
            # on a device mismatch when the inputs are on an accelerator.
            cls_padding = torch.ones(
                (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device
            )
            attention_mask = torch.cat((attention_mask, cls_padding), dim=1)
            attention_mask = create_causal_mask(
                config=self.config,
                input_embeds=hidden_states,
                attention_mask=attention_mask,
                cache_position=torch.arange(hidden_states.shape[1] + 1, device=hidden_states.device),
                past_key_values=None,
            )

        # The table is built on CPU in float32; move it to the embeddings' device and dtype
        # so the addition works on accelerators and does not upcast half-precision inputs.
        positions = self.create_sinusoidal_positions(seq_length, self.config.hidden_size).to(
            device=hidden_states.device, dtype=hidden_states.dtype
        )
        features = hidden_states + positions
        cls_emb = self.cls_emb * (self.config.hidden_size**0.5)
        cls_emb = cls_emb.expand(features.shape[0], -1, -1)
        features = torch.cat((features, cls_emb), dim=1)
        text_encoder_output = self.text_encoder(features, attention_mask)
        features = text_encoder_output.last_hidden_state
        features = self.layernorm(features)
        # The class token sits in the last position and serves as the text embedding.
        text_embeddings = features[:, -1]

        if self.normalize:
            text_embeddings = l2norm(text_embeddings, dim=-1)

        return BaseModelOutput(
            last_hidden_state=text_embeddings,
        )
|
||||
|
||||
|
||||
|
||||
class VideoPrismVideoModel(VideoPrismPreTrainedModel):
    """Video tower: vision backbone, auxiliary encoder, and attention pooling to one embedding."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        self.backbone = VideoPrismVisionModel(config)
        self.auxiliary_encoder = VideoPrismAuxiliaryEncoder(config)
        self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(config)
        self.normalize = config.apply_l2_norm
        self.post_init()

    def get_input_embeddings(self):
        return self.backbone.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> VideoPrismVideoOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings to match input size.
        """
        backbone_outputs = self.backbone(
            pixel_values_videos=pixel_values_videos,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )
        # The auxiliary encoder is called without a mask or extra kwargs.
        auxiliary_output = self.auxiliary_encoder(backbone_outputs.last_hidden_state)
        pooler_output = self.contrastive_vision_pooler(auxiliary_output.last_hidden_state, **kwargs)

        # The pooler returns (pooled, attention_probs); the embedding is the first item.
        video_embeddings = pooler_output[0]
        if self.normalize:
            video_embeddings = l2norm(video_embeddings, dim=-1)

        return VideoPrismVideoOutput(
            video_last_hidden_state=video_embeddings,
            auxiliary_output=auxiliary_output,
            attention_pooling_output=pooler_output,
        )
|
||||
|
||||
|
||||
class VideoPrismClipModel(VideoPrismPreTrainedModel):
    """Contrastive video-text model combining the video and text towers."""

    config_class = VideoPrismConfig

    def __init__(self, config: VideoPrismConfig):
        super().__init__(config)
        self.config = config
        self.vision_config = config.vision_config
        self.text_config = config.text_config
        self.video_model = VideoPrismVideoModel(self.vision_config)
        self.text_model = VideoPrismTextModel(self.text_config)
        self.post_init()

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        temperature: float | None = None,
        **kwargs,
    ) -> VideoPrismClipOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            input_ids (`torch.Tensor`):
                Input token IDs for text.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask for text inputs.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings.
            temperature (`float`, *optional*):
                Temperature parameter for scaling similarity scores.

        Raises:
            ValueError: If the video and text embeddings have different last dimensions.

        Example:

        ```python
        >>> from transformers import VideoPrismProcessor, VideoPrismClipModel
        >>> import torch

        >>> processor = VideoPrismProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismClipModel.from_pretrained("google/videoprism")

        >>> video = "sample_video.mp4"
        >>> texts = ["a dog", "a cat"]
        >>> inputs = processor(videos=video, texts=texts, return_tensors="pt", padding=True)

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits_per_video = outputs.logits_per_video
        ```
        """
        video_model_outputs = self.video_model(
            pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
        )
        text_model_outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        video_embeddings = video_model_outputs.video_last_hidden_state
        text_embeddings = text_model_outputs.last_hidden_state
        emb_dim = video_embeddings[0].shape[-1]
        text_dim = text_embeddings[0].shape[-1]
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if emb_dim != text_dim:
            raise ValueError(
                f"Video embedding size {emb_dim} does not match text embedding size {text_dim}."
            )

        video_embeds = video_embeddings.reshape(-1, emb_dim)
        text_embeds = text_embeddings.reshape(-1, emb_dim)
        # Rows index videos, columns index texts.
        similarity_matrix = torch.matmul(video_embeds, text_embeds.T)

        if temperature is not None:
            similarity_matrix /= temperature

        # Exponentiate once, then normalise each view along its first dimension.
        logits_per_video = torch.exp(similarity_matrix)
        logits_per_text = logits_per_video.T
        logits_per_video = logits_per_video / torch.sum(logits_per_video, dim=0, keepdims=True)
        logits_per_text = logits_per_text / torch.sum(logits_per_text, dim=0, keepdims=True)

        return VideoPrismClipOutput(
            logits_per_video=logits_per_video,
            logits_per_text=logits_per_text,
            video_embeds=video_embeds,
            text_embeds=text_embeds,
        )
|
||||
|
||||
|
||||
|
||||
class VideoPrismForVideoClassification(VideoPrismPreTrainedModel):
    """Vision backbone with attention pooling and a linear classification head."""

    config_class = VideoPrismVisionConfig
    config: VideoPrismVisionConfig

    def __init__(self, config: VideoPrismVisionConfig):
        super().__init__(config)
        self.config = config
        self.encoder = VideoPrismVisionModel(self.config)
        self.contrastive_vision_pooler = VideoPrismMultiheadAttentionPoolingHead(self.config)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
        self.post_init()

    def get_input_embeddings(self):
        return self.encoder.spatial_embeddings.patch_embeddings

    def forward(
        self,
        pixel_values_videos: torch.FloatTensor,
        labels: torch.LongTensor | None = None,
        interpolate_pos_encoding: bool | None = False,
        **kwargs,
    ) -> ImageClassifierOutput:
        r"""
        Args:
            pixel_values_videos (`torch.FloatTensor`):
                Pixel values of the video frames.
            labels (`torch.LongTensor`, *optional*):
                Video classification labels.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate positional encodings.

        Example:

        ```python
        >>> from transformers import VideoPrismVideoProcessor, VideoPrismForVideoClassification
        >>> import torch

        >>> processor = VideoPrismVideoProcessor.from_pretrained("google/videoprism")
        >>> model = VideoPrismForVideoClassification.from_pretrained("google/videoprism", num_labels=1000)

        >>> video = "sample_video.mp4"
        >>> inputs = processor(videos=video, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits
        ```
        """
        encoder_outputs = self.encoder(
            pixel_values_videos=pixel_values_videos, interpolate_pos_encoding=interpolate_pos_encoding, **kwargs
        )
        sequence_output = encoder_outputs.last_hidden_state
        # The pooling head returns a plain `(pooled, attention_probs)` tuple, so the pooled
        # tensor is taken by position; attribute access on the tuple would raise.
        pooled_output = self.contrastive_vision_pooler(sequence_output, **kwargs)[0]
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config, **kwargs)

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.last_hidden_state,
        )
|
||||
|
||||
|
||||
# Names exported by `from <module> import *`.
__all__ = [
    "VideoPrismVisionModel",
    "VideoPrismPreTrainedModel",
    "VideoPrismVideoModel",
    "VideoPrismTextModel",
    "VideoPrismClipModel",
    "VideoPrismForVideoClassification",
]
|
||||
@@ -1,50 +0,0 @@
|
||||
"""Latency smoke test: run VideoPrismVisionModel on one clip and time repeated forward passes."""

import time

import numpy as np
import torch
from torchcodec.decoders import VideoDecoder

from lerobot.policies.videovla.videoprism import VideoPrismVideoProcessor
from lerobot.policies.videovla.videoprism import VideoPrismVisionModel

processor = VideoPrismVideoProcessor.from_pretrained(
    "MHRDYN7/videoprism-base-f16r288"
)

model = VideoPrismVisionModel.from_pretrained(
    "MHRDYN7/videoprism-base-f16r288",
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)

video_url = "https://huggingface.co/datasets/nateraw/kinetics-mini/resolve/main/val/archery/-Qz25rXdMjE_000014_000024.mp4"

vr = VideoDecoder(video_url)
frame_idx = np.arange(0, 64)
video = vr.get_frames_at(indices=frame_idx).data  # T x C x H x W

video = processor(video, return_tensors="pt")
video = {k: v.to(model.device, model.dtype) for k, v in video.items()}


def _sync() -> None:
    """Wait for queued GPU work so wall-clock timings cover the whole forward pass."""
    # Only synchronize when the model actually lives on a CUDA device; calling
    # torch.cuda.synchronize() on a machine without CUDA raises.
    if model.device.type == "cuda":
        torch.cuda.synchronize()


# Inference only: disabling autograd avoids building graphs on every timed pass.
with torch.no_grad():
    outputs = model(**video)
    encoder_outputs = outputs.last_hidden_state
    print(encoder_outputs.shape)

    # warmup
    for _ in range(10):
        _ = model(**video)

    times = []
    for _ in range(50):
        _sync()
        t0 = time.perf_counter()

        _ = model(**video)

        _sync()
        t1 = time.perf_counter()
        times.append(t1 - t0)

print(f"Mean: {1000*sum(times)/len(times):.2f} ms")
print(f"Min : {1000*min(times):.2f} ms")
print(f"Max : {1000*max(times):.2f} ms")
|
||||
@@ -1,44 +0,0 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from src/transformers/models/videoprism/modular_videoprism.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_videoprism.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
|
||||
|
||||
class VideoPrismVideoProcessor(BaseVideoProcessor):
    r"""
    Constructs a VideoPrism video processor.

    This processor inherits from [`BaseVideoProcessor`] and sets default parameters for VideoPrism models.
    Video frames are resized to 288x288 using bicubic resampling without normalization.

    Args:
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 288, "width": 288}`):
            The size to resize the video frames to.
        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
            The resampling filter to use when resizing images.
        do_normalize (`bool`, *optional*, defaults to `False`):
            Whether to normalize the video frames.
    """

    resample = PILImageResampling.BICUBIC
    # Mean/std are declared, but `do_normalize` below defaults to False.
    image_mean = OPENAI_CLIP_MEAN
    image_std = OPENAI_CLIP_STD

    size = {"height": 288, "width": 288}
    rescale_factor = 1 / 255
    default_to_square = False
    crop_size = None
    do_resize = True
    do_center_crop = None
    do_rescale = True
    do_normalize = False
    do_convert_rgb = True
    do_sample_frames = False  # Set to False for BC, recommended to set `True` in new models
|
||||
do_convert_rgb = True
|
||||
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
|
||||
|
||||
|
||||
# Names exported by `from <module> import *`.
__all__ = ["VideoPrismVideoProcessor"]
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .audio_processor import AudioProcessorStep
|
||||
from .batch_processor import AddBatchDimensionProcessorStep
|
||||
from .converters import (
|
||||
batch_to_transition,
|
||||
@@ -80,6 +81,7 @@ __all__ = [
|
||||
"ActionProcessorStep",
|
||||
"AddTeleopActionAsComplimentaryDataStep",
|
||||
"AddTeleopEventsAsInfoStep",
|
||||
"AudioProcessorStep",
|
||||
"ComplementaryDataProcessorStep",
|
||||
"batch_to_transition",
|
||||
"create_transition",
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from torch import Tensor
|
||||
from torchaudio.functional import amplitude_to_DB
|
||||
from torchaudio.transforms import MelSpectrogram, Resample
|
||||
from torchvision.transforms import Compose, Lambda, Resize
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.utils.constants import OBS_AUDIO
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
@ProcessorStepRegistry.register(name="audio_processor")
class AudioProcessorStep(ObservationProcessorStep):
    """
    Processes audio waveform data into a mel-spectrogram image representation.

    **Audio Processing:**
    -   Averages waveform data over all channels.
    -   Resamples the waveform from `input_sample_rate` to `intermediate_sample_rate` (16kHz by default).
    -   Converts the waveform to a mel-spectrogram (power spectrum).
    -   Converts the mel-spectrogram to decibels.
    -   Resizes the mel-spectrogram to `output_height`×`output_width` (224×224 by default).
    -   Duplicates the mel-spectrogram across `output_channels` to obtain a channel-first, image-like tensor.
        No value normalization is applied by this step.

    Only tensors with exactly 3 dimensions (batch, channels, samples) stored under the audio observation key,
    or under keys prefixed with it, are transformed; every other entry is passed through unchanged.

    Attributes:
        output_height: Height of the output mel-spectrogram image in pixels.
        output_width: Width of the output mel-spectrogram image in pixels.
        output_channels: Number of channels in the output image (3 for RGB-like format).
        input_audio_chunk_duration: Duration of the input audio chunk in seconds.
        input_sample_rate: Original sample rate of the input audio in Hz.

        intermediate_sample_rate: Reduced intermediate sample rate in Hz.
            Downsampling improves the temporal resolution but reduces the frequency range.
        n_fft: Size of the FFT window for spectrogram computation.
            Increasing the window size increases the frequency resolution but decreases the temporal resolution.

        hop_length: Number of samples between successive frames, computed automatically so that roughly
            `output_width` frames cover one audio chunk (the final resize absorbs any remaining mismatch).
            Decreasing the hop length increases the temporal resolution but decreases the frequency resolution.
        n_mels: Number of mel filter banks, computed automatically to match the output_height.
            Increasing the number of banks increases the number of rows in the spectrogram and the frequency resolution.
        mel_spectrogram_transform: The complete audio processing pipeline.
    """

    output_height: int = 224
    output_width: int = 224
    output_channels: int = 3
    input_audio_chunk_duration: float = DEFAULT_AUDIO_CHUNK_DURATION

    input_sample_rate: int = 48000
    intermediate_sample_rate: int = 16000

    n_fft: int = 1024

    # Parameters computed from other parameters at initialization
    hop_length: int = field(init=False)
    n_mels: int = field(init=False)
    mel_spectrogram_transform: Compose = field(init=False, repr=False)

    def __post_init__(self):
        """Derive `hop_length` and `n_mels`, then build the waveform-to-image transform pipeline."""
        # Number of samples in one audio chunk once resampled to the intermediate rate.
        chunk_samples = self.intermediate_sample_rate * self.input_audio_chunk_duration

        # Spread `output_width` frames over the chunk: (samples - window) // (frames - 1).
        # The parentheses matter: without them operator precedence yields
        # `samples - (n_fft // output_width) - 1`, i.e. a hop almost as long as the whole chunk
        # and therefore only a couple of spectrogram frames.
        # The divisor and the result are clamped to 1 so that a width of 1 or a chunk shorter than
        # the FFT window still produces a valid (strictly positive) hop length.
        frame_intervals = max(self.output_width - 1, 1)
        self.hop_length = max(1, int((chunk_samples - self.n_fft) // frame_intervals))
        self.n_mels = self.output_height

        self.mel_spectrogram_transform = Compose(
            [
                Lambda(lambda x: x.mean(dim=1)),  # Average over all channels (second dimension after batch)
                Resample(orig_freq=self.input_sample_rate, new_freq=self.intermediate_sample_rate),
                MelSpectrogram(
                    sample_rate=self.intermediate_sample_rate,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    n_mels=self.n_mels,
                    power=2,  # Power spectrum
                ),
                Lambda(
                    lambda x: amplitude_to_DB(x, multiplier=10, amin=1e-10, db_multiplier=0)
                ),  # Convert to decibels
                Resize(
                    (self.output_height, self.output_width)
                ),  # Resize spectrogram to output_height×output_width
                Lambda(
                    lambda x: x.unsqueeze(1).expand(-1, self.output_channels, -1, -1)
                ),  # Duplicate across channels to mimic RGB images. Dimensions are [batch, rgb, height, width].
            ]
        )

    def _process_observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
        """
        Processes audio data contained in the provided observation.

        Args:
            observation: Observation dictionary. It is not modified; a shallow copy is returned.

        Returns:
            A shallow copy of `observation` in which every 3D tensor (batch, channels, samples) stored under
            the audio key, or under a key prefixed with the audio key and a dot, is replaced by its
            mel-spectrogram image.
        """
        processed_obs = observation.copy()

        # Handle the single audio observation and the per-microphone observations in one pass.
        # Iterating the input dict while writing to the copy keeps the iteration safe.
        for key, value in observation.items():
            is_audio_key = key == OBS_AUDIO or key.startswith(f"{OBS_AUDIO}.")
            if is_audio_key and isinstance(value, Tensor) and value.dim() == 3:  # Batch, Channels, Samples
                processed_obs[key] = self.mel_spectrogram_transform(value)

        return processed_obs

    def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
        """Entry point of the observation step: delegates to `_process_observation`."""
        return self._process_observation(observation)
|
||||
@@ -25,7 +25,7 @@ from dataclasses import dataclass, field
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .core import EnvTransition, PolicyAction
|
||||
from .pipeline import (
|
||||
@@ -88,6 +88,8 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
- State vectors (1D tensors).
|
||||
- Single images (3D tensors).
|
||||
- Dictionaries of multiple images (3D tensors).
|
||||
- Single audio waveforms (2D tensors).
|
||||
- Dictionaries of multiple audio waveforms (2D tensors).
|
||||
"""
|
||||
|
||||
def observation(self, observation: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
@@ -117,6 +119,18 @@ class AddBatchDimensionObservationStep(ObservationProcessorStep):
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_IMAGES}.") and isinstance(value, Tensor) and value.dim() == 3:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
# Process single audio observation - add batch dim if 2D
|
||||
if OBS_AUDIO in observation:
|
||||
audio_value = observation[OBS_AUDIO]
|
||||
if isinstance(audio_value, Tensor) and audio_value.dim() == 2:
|
||||
observation[OBS_AUDIO] = audio_value.unsqueeze(0)
|
||||
|
||||
# Process multiple audio observations - add batch dim if 2D
|
||||
for key, value in observation.items():
|
||||
if key.startswith(f"{OBS_AUDIO}.") and isinstance(value, Tensor) and value.dim() == 2:
|
||||
observation[key] = value.unsqueeze(0)
|
||||
|
||||
return observation
|
||||
|
||||
def transform_features(
|
||||
|
||||
@@ -34,6 +34,13 @@ class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
raise ValueError(
|
||||
f"Specifying '{attr}' is required for the camera to be used in a robot"
|
||||
)
|
||||
if hasattr(self, "microphones") and self.microphones:
|
||||
for _, config in self.microphones.items():
|
||||
for attr in ["sample_rate", "channels"]:
|
||||
if getattr(config, attr) is None:
|
||||
raise ValueError(
|
||||
f"Specifying '{attr}' is required for the microphone to be used in a robot"
|
||||
)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -31,6 +32,8 @@ class HopeJrHandConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.side not in ["right", "left"]:
|
||||
@@ -49,3 +52,5 @@ class HopeJrArmConfig(RobotConfig):
|
||||
max_relative_target: float | dict[str, float] | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -144,6 +144,13 @@ class HopeJrArm(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -175,6 +175,13 @@ class HopeJrHand(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -35,5 +36,8 @@ class KochFollowerConfig(RobotConfig):
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# microphones
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -19,6 +19,7 @@ import time
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
@@ -61,6 +62,7 @@ class KochFollower(Robot):
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
@@ -72,9 +74,16 @@ class KochFollower(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """
    Build the microphone part of the observation features.

    Returns:
        A dict mapping every key of `self.microphones` to the `(sample_rate, channels)` pair read from the
        matching entry of `self.config.microphones`.
    """
    features = {}
    for mic in self.microphones:
        sample_rate = self.config.microphones[mic].sample_rate
        channels = self.config.microphones[mic].channels
        features[mic] = (sample_rate, channels)
    return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -82,7 +91,11 @@ class KochFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
@@ -101,6 +114,9 @@ class KochFollower(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -197,6 +213,13 @@ class KochFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -232,5 +255,7 @@ class KochFollower(Robot):
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -16,6 +16,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras.configs import CameraConfig, Cv2Rotation
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -45,6 +46,8 @@ class LeKiwiConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
@@ -92,5 +95,7 @@ class LeKiwiClientConfig(RobotConfig):
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config)
|
||||
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
polling_timeout_ms: int = 15
|
||||
connect_timeout_s: int = 5
|
||||
|
||||
@@ -23,6 +23,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
@@ -73,6 +74,7 @@ class LeKiwi(Robot):
|
||||
self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")]
|
||||
self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")]
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _state_ft(self) -> dict[str, type]:
|
||||
@@ -97,9 +99,16 @@ class LeKiwi(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """
    Build the microphone part of the observation features.

    Returns:
        A dict mapping every key of `self.microphones` to the `(sample_rate, channels)` pair read from the
        matching entry of `self.config.microphones`.
    """
    features = {}
    for mic in self.microphones:
        sample_rate = self.config.microphones[mic].sample_rate
        channels = self.config.microphones[mic].channels
        features[mic] = (sample_rate, channels)
    return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -107,7 +116,11 @@ class LeKiwi(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
@@ -121,6 +134,9 @@ class LeKiwi(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -364,6 +380,13 @@ class LeKiwi(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -413,5 +436,7 @@ class LeKiwi(Robot):
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -18,6 +18,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from time import perf_counter
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -58,8 +59,9 @@ class LeKiwiClient(Robot):
|
||||
self.zmq_observation_socket = None
|
||||
|
||||
self.last_frames = {}
|
||||
|
||||
self.last_remote_state = {}
|
||||
self.last_frame_timestamp = None
|
||||
self.last_frame_delay = 0.0
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
@@ -97,9 +99,13 @@ class LeKiwiClient(Robot):
|
||||
def _cameras_ft(self) -> dict[str, tuple[int, int, int]]:
|
||||
return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()}
|
||||
|
||||
@cached_property
def _microphones_ft(self) -> dict[str, tuple]:
    """
    Build the microphone part of the observation features from the client configuration.

    Returns:
        A dict mapping every microphone name in `self.config.microphones` to its
        `(sample_rate, channels)` pair.
    """
    features = {}
    for name, cfg in self.config.microphones.items():
        features[name] = (cfg.sample_rate, cfg.channels)
    return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._state_ft, **self._cameras_ft}
|
||||
return {**self._state_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -135,6 +141,7 @@ class LeKiwiClient(Robot):
|
||||
if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN:
|
||||
raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.")
|
||||
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
self._is_connected = True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
@@ -167,6 +174,8 @@ class LeKiwiClient(Robot):
|
||||
if last_msg is None:
|
||||
logging.warning("Poller indicated data, but failed to retrieve message.")
|
||||
|
||||
self.last_frame_delay = perf_counter() - self.last_frame_timestamp
|
||||
self.last_frame_timestamp = perf_counter()
|
||||
return last_msg
|
||||
|
||||
def _parse_observation_json(self, obs_string: str) -> RobotObservation | None:
|
||||
@@ -203,14 +212,16 @@ class LeKiwiClient(Robot):
|
||||
|
||||
obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec}
|
||||
|
||||
# Decode images
|
||||
# Decode images and audio data
|
||||
current_frames: dict[str, np.ndarray] = {}
|
||||
for cam_name, image_b64 in observation.items():
|
||||
if cam_name not in self._cameras_ft:
|
||||
continue
|
||||
frame = self._decode_image_from_b64(image_b64)
|
||||
if frame is not None:
|
||||
current_frames[cam_name] = frame
|
||||
for frame_name, frame_data in observation.items():
|
||||
if frame_name in self._cameras_ft:
|
||||
image = self._decode_image_from_b64(frame_data)
|
||||
if image is not None:
|
||||
current_frames[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
if frame_data is not None:
|
||||
current_frames[frame_name] = frame_data
|
||||
|
||||
return current_frames, obs_dict
|
||||
|
||||
@@ -254,17 +265,27 @@ class LeKiwiClient(Robot):
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||
and cameras and microphones data. Receives over ZMQ, translate to body-frame vel
|
||||
"""
|
||||
|
||||
frames, obs_dict = self._get_data()
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, frame in frames.items():
|
||||
if frame is None:
|
||||
logging.warning("Frame is None")
|
||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = frame
|
||||
# Loop over each configured camera and microphone
|
||||
for frame_name, frame_data in frames.items():
|
||||
if frame_data is None:
|
||||
if frame_name in self._cameras_ft:
|
||||
logging.warning("Image frame is None")
|
||||
image = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[frame_name] = image
|
||||
elif frame_name in self._microphones_ft:
|
||||
logging.warning("Audio frame is None")
|
||||
obs_dict[frame_name] = np.zeros(
|
||||
(
|
||||
int(self._microphones_ft[frame_name][0] * self.last_frame_delay),
|
||||
self._microphones_ft[frame_name][1],
|
||||
),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
return obs_dict
|
||||
|
||||
|
||||
@@ -58,32 +58,6 @@ class Robot(abc.ABC):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
def __enter__(self):
    """
    Context manager entry.

    Automatically connects to the robot by calling `connect()`.

    Returns:
        The robot instance itself, so it can be bound with `with ... as robot`.
    """
    self.connect()
    return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
    """
    Context manager exit.

    Calls `disconnect()` unconditionally, so resources are released whether or not the `with` body raised.
    Returns `None`, which means an exception raised inside the `with` body is not suppressed.
    """
    self.disconnect()
    return None
|
||||
|
||||
def __del__(self) -> None:
    """
    Destructor safety net.

    If the object is garbage collected while `is_connected` is still true, try to disconnect it.
    Any exception raised while checking or disconnecting is deliberately swallowed (best effort).
    """
    try:
        still_connected = self.is_connected
        if still_connected:
            self.disconnect()
    except Exception:  # nosec B110
        # Best-effort cleanup only: errors are intentionally ignored here.
        pass
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -18,6 +18,7 @@ from dataclasses import dataclass, field
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
from lerobot.microphones import MicrophoneConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
@@ -39,6 +40,9 @@ class SOFollowerConfig:
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
# microphones
|
||||
microphones: dict[str, MicrophoneConfig] = field(default_factory=dict)
|
||||
|
||||
# Set to `True` for backward compatibility with previous policies/dataset
|
||||
use_degrees: bool = False
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from functools import cached_property
|
||||
from typing import TypeAlias
|
||||
|
||||
from lerobot.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.microphones.utils import make_microphones_from_configs
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
@@ -62,6 +63,7 @@ class SOFollower(Robot):
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
self.microphones = make_microphones_from_configs(config.microphones)
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
@@ -73,9 +75,16 @@ class SOFollower(Robot):
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@property
def _microphones_ft(self) -> dict[str, tuple]:
    """
    Build the microphone part of the observation features.

    Returns:
        A dict mapping every key of `self.microphones` to the `(sample_rate, channels)` pair read from the
        matching entry of `self.config.microphones`.
    """
    features = {}
    for mic in self.microphones:
        sample_rate = self.config.microphones[mic].sample_rate
        channels = self.config.microphones[mic].channels
        features[mic] = (sample_rate, channels)
    return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
return {**self._motors_ft, **self._cameras_ft, **self._microphones_ft}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
@@ -83,7 +92,11 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
return (
|
||||
self.bus.is_connected
|
||||
and all(cam.is_connected for cam in self.cameras.values())
|
||||
and all(mic.is_connected for mic in self.microphones.values())
|
||||
)
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
@@ -102,6 +115,9 @@ class SOFollower(Robot):
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
for mic in self.microphones.values():
|
||||
mic.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@@ -191,6 +207,13 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Read audio frames from microphones
|
||||
for mic_key, mic in self.microphones.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[mic_key] = mic.read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {mic_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -226,6 +249,8 @@ class SOFollower(Robot):
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
for mic in self.microphones.values():
|
||||
mic.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ import draccus
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
|
||||
@@ -64,11 +64,14 @@ lerobot-record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
)
|
||||
@@ -81,8 +84,23 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_INITIAL_AUDIO_BUFFER_DURATION,
|
||||
build_dataset_frame,
|
||||
combine_feature_dicts,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.microphones import (
|
||||
MicrophoneConfig, # noqa: F401
|
||||
)
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.microphones.utils import (
|
||||
async_microphones_start_recording,
|
||||
async_microphones_stop_recording,
|
||||
)
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
@@ -120,6 +138,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.audio_utils import rolling_vstack
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
@@ -279,6 +298,13 @@ def record_loop(
|
||||
display_data: bool = False,
|
||||
display_compressed_images: bool = False,
|
||||
):
|
||||
if display_data:
|
||||
init_rerun(
|
||||
session_name="recording",
|
||||
robot=robot,
|
||||
reset_time=True,
|
||||
)
|
||||
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
@@ -313,6 +339,36 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
# Create a buffer for audio observations (shifting window of fixed size over audio samples)
|
||||
if robot.microphones and (policy is not None or dataset is not None):
|
||||
audio_buffer = {
|
||||
microphone_name: np.zeros(
|
||||
(int(microphone.sample_rate * DEFAULT_AUDIO_CHUNK_DURATION), len(microphone.channels))
|
||||
)
|
||||
for microphone_name, microphone in robot.microphones.items()
|
||||
}
|
||||
|
||||
if (
|
||||
dataset is not None and robot.name != "lekiwi"
|
||||
): # For now, LeKiwi only supports frame audio recording (which may lead to audio chunks loss, extended post-processing, increased memory usage)
|
||||
dataset.add_microphones_recordings(robot.microphones)
|
||||
else:
|
||||
async_microphones_start_recording(robot.microphones)
|
||||
|
||||
# Fill audio buffers if needed
|
||||
if (
|
||||
robot.microphones
|
||||
and (policy is not None or dataset is not None)
|
||||
and DEFAULT_INITIAL_AUDIO_BUFFER_DURATION > 0.0
|
||||
):
|
||||
# This initial wait might be longer than the audio chunk duration to
|
||||
# (1) ensure that the audio buffers are filled with enough data
|
||||
# (2) add additional initial samples to the dataset in case of variable audio chunk duration during training
|
||||
precise_sleep(DEFAULT_INITIAL_AUDIO_BUFFER_DURATION)
|
||||
for microphone_name, microphone in robot.microphones.items():
|
||||
audio_chunk = microphone.read()
|
||||
audio_buffer[microphone_name] = rolling_vstack(audio_buffer[microphone_name], audio_chunk)
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -333,8 +389,14 @@ def record_loop(
|
||||
|
||||
# Get action from either policy or teleop
|
||||
if policy is not None and preprocessor is not None and postprocessor is not None:
|
||||
# Transform instantaneous audio samples into a buffer of fixed size
|
||||
buffered_observation_frame = copy(observation_frame)
|
||||
for name in audio_buffer:
|
||||
# Add the audio buffer to the observation
|
||||
buffered_observation_frame[name] = rolling_vstack(audio_buffer[name], observation_frame[name])
|
||||
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
observation=buffered_observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
@@ -389,14 +451,26 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
observation=obs_processed,
|
||||
action=action_values,
|
||||
compress_images=display_compressed_images,
|
||||
log_time=time.perf_counter() - start_episode_t,
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
precise_sleep(max(1 / fps - dt_s, 0.0))
|
||||
remaining_time = 1 / fps - dt_s
|
||||
if remaining_time > 0.0:
|
||||
print(f"Waiting {remaining_time:.2f} seconds to maintain {fps:.2f} Hz control loop frequency.")
|
||||
precise_sleep(remaining_time)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Inconsistent control loop frequency: {1 / dt_s:.2f} Hz < {fps:.2f} Hz. Try reducing the cameras resolution or FPS."
|
||||
)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
async_microphones_stop_recording(robot.microphones)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
@@ -61,6 +61,9 @@ import rerun as rr
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.microphones.anyskin.configuration_anyskin import AnyskinSensorConfig # noqa: F401
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig # noqa: F401
|
||||
from lerobot.microphones.touchlab.configuration_touchlab import TouchLabSensorConfig # noqa: F401
|
||||
from lerobot.processor import (
|
||||
RobotAction,
|
||||
RobotObservation,
|
||||
@@ -143,8 +146,18 @@ def teleop_loop(
|
||||
robot_action_processor: An optional pipeline to process actions before they are sent to the robot.
|
||||
robot_observation_processor: An optional pipeline to process raw observations from the robot.
|
||||
"""
|
||||
if display_data:
|
||||
init_rerun(
|
||||
session_name="teleoperation",
|
||||
robot=robot,
|
||||
reset_time=True,
|
||||
)
|
||||
|
||||
display_len = max(len(key) for key in robot.action_features)
|
||||
|
||||
for _, microphone in robot.microphones.items():
|
||||
microphone.start_recording()
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
while True:
|
||||
@@ -176,6 +189,7 @@ def teleop_loop(
|
||||
observation=obs_transition,
|
||||
action=teleop_action,
|
||||
compress_images=display_compressed_images,
|
||||
log_time=time.perf_counter() - start,
|
||||
)
|
||||
|
||||
print("\n" + "-" * (display_len + 10))
|
||||
@@ -192,7 +206,10 @@ def teleop_loop(
|
||||
move_cursor_up(1)
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
break
|
||||
|
||||
for _, microphone in robot.microphones.items():
|
||||
microphone.stop_recording()
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -58,32 +58,6 @@ class Teleoperator(abc.ABC):
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Context manager entry.
|
||||
Automatically connects to the camera.
|
||||
"""
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""
|
||||
Context manager exit.
|
||||
Automatically disconnects, ensuring resources are released even on error.
|
||||
"""
|
||||
self.disconnect()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Destructor safety net.
|
||||
Attempts to disconnect if the object is garbage collected without cleanup.
|
||||
"""
|
||||
try:
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
except Exception: # nosec B110
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_features(self) -> dict:
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def rolling_vstack(buffer: np.ndarray, new_data: np.ndarray) -> np.ndarray:
    """
    Rolling implementation of numpy.vstack: append new data at the end of a fixed shape buffer, dropping the
    oldest rows to make room.

    The buffer is updated *in place* and also returned for convenience.

    Args:
        buffer: The *fixed* shape buffer to update. Its first axis is the rolling axis.
        new_data: The new data to add to the buffer. If it holds more rows than the buffer, only the newest
            rows are kept. If it is empty, the buffer is left untouched.

    Returns:
        The updated buffer (the same object as `buffer`).
    """
    buffer_size = buffer.shape[0]
    num_new = len(new_data)

    # Nothing to add (or nowhere to put it). Without this guard `buffer[:-0]` is an empty slice and the
    # assignments below fail with a broadcasting error.
    if num_new == 0 or buffer_size == 0:
        return buffer

    # The new data fills (or overflows) the whole buffer: keep only its newest rows
    if num_new >= buffer_size:
        buffer[:] = new_data[-buffer_size:]
        return buffer

    # Remove as many old samples as needed by shifting the remaining ones towards the start
    buffer[:-num_new] = buffer[num_new:]
    # Add the new samples at the end
    buffer[-num_new:] = new_data
    return buffer
|
||||
@@ -23,6 +23,7 @@ OBS_ENV_STATE = OBS_STR + ".environment_state"
|
||||
OBS_STATE = OBS_STR + ".state"
|
||||
OBS_IMAGE = OBS_STR + ".image"
|
||||
OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_AUDIO = OBS_STR + ".audio"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
|
||||
@@ -102,7 +102,7 @@ def predict_action(
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
# Convert to pytorch format: normalizing and permuting (channel first)
|
||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||
observation = preprocessor(observation)
|
||||
|
||||
|
||||
@@ -30,3 +30,22 @@ class DeviceAlreadyConnectedError(ConnectionError):
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeviceNotRecordingError(Exception):
    """Raised when an operation requires the robot device to be recording, but it is not."""

    def __init__(
        self,
        message="This robot device is not recording. Try calling `start_recording()` first.",
    ):
        # Initialise the base exception with the message, then keep it accessible as an attribute
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
class DeviceAlreadyRecordingError(Exception):
    """Raised when a recording is started on a robot device that is already recording."""

    def __init__(
        self, message="This robot device is already recording. Try not calling `start_recording()` twice."
    ):
        # Initialise the base exception with the message, then keep it accessible as an attribute
        super().__init__(message)
        self.message = message
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from multiprocessing import Lock, Value, shared_memory
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class SharedArray:
|
||||
"""
|
||||
A SharedArray is a numpy array shared between multiple processes in a shared_memory object.
|
||||
- Data is written to the array using the `write` method, which appends data to the array.
|
||||
- Data is read from the array (and eventually flushed) using the `read` method, which copies the _whole_ array.
|
||||
|
||||
SharedArray offers quasi-instantaneous array-wide read and flush capabilities in comparison to Queues, but has a limited size defined at initialization.
|
||||
|
||||
Example:
|
||||
_Main_process_
|
||||
shared_array = SharedArray(shape=(10, 10), dtype=np.dtype("float32"))
|
||||
local_array = shared_array.get_local_array()
|
||||
shared_array.write(local_array, np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
|
||||
_Child_process_
|
||||
local_array = shared_array.get_local_array()
|
||||
data = shared_array.read(local_array, flush=True)
|
||||
"""
|
||||
|
||||
def __init__(self, shape: tuple[int], dtype: np.dtype | str):
|
||||
"""
|
||||
Initialize a SharedArray.
|
||||
|
||||
Args:
|
||||
shape: The shape of the shared array.
|
||||
dtype: The dtype of the shared array.
|
||||
"""
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=np.prod(shape) * np.dtype(dtype).itemsize
|
||||
)
|
||||
self.read_index = Value("i", 0)
|
||||
self.lock = Lock()
|
||||
|
||||
def get_local_array(self) -> np.ndarray:
|
||||
"""
|
||||
Get a process-local instance of the shared array.
|
||||
|
||||
Returns:
|
||||
A process-local instance of the shared array.
|
||||
"""
|
||||
return np.ndarray(self.shape, dtype=np.dtype(self.dtype), buffer=self.shared_memory.buf)
|
||||
|
||||
def delete(self):
|
||||
"""
|
||||
Delete the shared array.
|
||||
"""
|
||||
self.shared_memory.close()
|
||||
self.shared_memory.unlink()
|
||||
|
||||
def write(self, local_array: np.ndarray, data: np.ndarray):
|
||||
"""
|
||||
Write data to the shared array.
|
||||
|
||||
Args:
|
||||
local_array: The process-local instance of the shared array to write to.
|
||||
data: The data to write to the shared array.
|
||||
"""
|
||||
with self.lock:
|
||||
local_array[self.read_index.value : self.read_index.value + len(data)] = data
|
||||
self.read_index.value += len(data)
|
||||
|
||||
def read(self, local_array: np.ndarray, flush: bool = True) -> np.ndarray:
|
||||
"""
|
||||
Read data from the shared array.
|
||||
|
||||
Args:
|
||||
local_array: The process-local instance of the shared array to read from.
|
||||
flush: Whether to flush the shared array after reading.
|
||||
"""
|
||||
with self.lock:
|
||||
data = np.copy(local_array[: self.read_index.value])
|
||||
if flush:
|
||||
self.read_index.value = 0
|
||||
return data
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the read index to 0.
|
||||
"""
|
||||
with self.lock:
|
||||
self.read_index.value = 0
|
||||
@@ -14,17 +14,25 @@
|
||||
|
||||
import numbers
|
||||
import os
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_AUDIO_CHUNK_DURATION
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots import Robot
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
|
||||
|
||||
def init_rerun(
|
||||
session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None
|
||||
session_name: str = "lerobot_control_loop",
|
||||
ip: str | None = None,
|
||||
port: int | None = None,
|
||||
robot: Robot | None = None,
|
||||
reset_time: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Rerun SDK for visualizing the control loop.
|
||||
@@ -33,16 +41,26 @@ def init_rerun(
|
||||
session_name: Name of the Rerun session.
|
||||
ip: Optional IP for connecting to a Rerun server.
|
||||
port: Optional port for connecting to a Rerun server.
|
||||
robot: A Robot object. If provided, Rerun will be initialized with a blueprint that includes the object's cameras and microphones.
|
||||
reset_time: Whether to reset the timer "episode_time" to 0.
|
||||
"""
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
rr.init(session_name)
|
||||
rr.init(
|
||||
application_id=session_name,
|
||||
recording_id=uuid4(),
|
||||
)
|
||||
if robot is not None:
|
||||
rr.send_blueprint(build_rerun_blueprint(robot))
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
if ip and port:
|
||||
rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy")
|
||||
else:
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
if reset_time:
|
||||
rr.set_time("episode_time", timestamp=0.0)
|
||||
|
||||
|
||||
def _is_scalar(x):
|
||||
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
||||
@@ -50,10 +68,47 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def build_rerun_blueprint(robot: Robot) -> rr.blueprint.Grid:
    """
    Builds a Rerun blueprint laying out the robot's observations and actions in a grid:
    - One time series view rooted at "data" (scalar observations and actions are logged under this prefix).
    - One time series view rooted at "audio", only if the robot has microphones.
    - One spatial 2D view per camera, only if the robot has cameras.

    Args:
        robot: A Robot object.
    Returns:
        A Rerun blueprint.
    """
    # The scalar time series view is always present
    views = [
        rr.blueprint.TimeSeriesView(
            origin="data",
            plot_legend=rr.blueprint.PlotLegend(visible=True),
        )
    ]

    if robot.microphones:
        views.append(
            rr.blueprint.TimeSeriesView(
                origin="audio",
                plot_legend=rr.blueprint.PlotLegend(visible=True),
            )
        )

    if robot.cameras:
        views.extend(
            rr.blueprint.Spatial2DView(origin=OBS_PREFIX + camera_name) for camera_name in robot.cameras
        )

    return rr.blueprint.Grid(*views)
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
compress_images: bool = False,
|
||||
log_time: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
@@ -72,7 +127,13 @@ def log_rerun_data(
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||
log_time: The time to log the data in the "episode_time" timeline.
|
||||
If None, the current time is used in Rerun's default timeline.
|
||||
"""
|
||||
if log_time is None:
|
||||
log_time = time.perf_counter()
|
||||
rr.set_time("episode_time", timestamp=log_time)
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
@@ -80,15 +141,41 @@ def log_rerun_data(
|
||||
key = k if str(k).startswith(OBS_PREFIX) else f"{OBS_STR}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
# Convert samples x channels -> channels x samples when needed
|
||||
elif arr.ndim == 2 and arr.shape[1] < arr.shape[0]:
|
||||
arr = np.transpose(arr, (1, 0))
|
||||
|
||||
if arr.ndim == 1:
|
||||
for i, vi in enumerate(arr):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
elif arr.ndim == 2:
|
||||
for i, channel_arr in enumerate(arr):
|
||||
rr.send_columns(
|
||||
"audio/"
|
||||
+ key
|
||||
+ f"_channel_{i}", # TODO(CarolinePascal): Get actual channel number/name
|
||||
indexes=[
|
||||
rr.TimeColumn(
|
||||
"episode_time",
|
||||
timestamp=log_time
|
||||
+ np.linspace(
|
||||
-DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
0,
|
||||
len(channel_arr),
|
||||
endpoint=False,
|
||||
),
|
||||
)
|
||||
],
|
||||
columns=rr.Scalars.columns(scalars=channel_arr),
|
||||
)
|
||||
elif arr.ndim == 3:
|
||||
rr.log(key, rr.Image(arr), static=True)
|
||||
else:
|
||||
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
|
||||
rr.log(key, entity=img_entity, static=True)
|
||||
@@ -100,13 +187,13 @@ def log_rerun_data(
|
||||
key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}"
|
||||
|
||||
if _is_scalar(v):
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
rr.log("data/" + key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
if v.ndim == 1:
|
||||
for i, vi in enumerate(v):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
else:
|
||||
# Fall back to flattening higher-dimensional arrays
|
||||
flat = v.flatten()
|
||||
for i, vi in enumerate(flat):
|
||||
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
rr.log("data/" + f"{key}_{i}", rr.Scalars(float(vi)))
|
||||
|
||||
@@ -144,18 +144,12 @@ def test_async_inference_e2e(monkeypatch):
|
||||
client = RobotClient(client_config)
|
||||
assert client.start(), "Client failed initial handshake with the server"
|
||||
|
||||
# Track action chunks received and verify device type
|
||||
action_chunks_received = {"count": 0, "actions_on_cpu": True}
|
||||
# Track action chunks received without modifying RobotClient
|
||||
action_chunks_received = {"count": 0}
|
||||
original_aggregate = client._aggregate_action_queues
|
||||
|
||||
def counting_aggregate(*args, **kwargs):
|
||||
action_chunks_received["count"] += 1
|
||||
# Check that all received actions are on CPU
|
||||
if args:
|
||||
for timed_action in args[0]: # args[0] is the list of TimedAction
|
||||
action_tensor = timed_action.get_action()
|
||||
if action_tensor.device.type != "cpu":
|
||||
action_chunks_received["actions_on_cpu"] = False
|
||||
return original_aggregate(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate)
|
||||
|
||||
@@ -57,6 +57,8 @@ def _check_component_availability(component_type, available_components, make_com
|
||||
print("\nNo physical device detected.")
|
||||
elif isinstance(e, ValueError) and "camera_index" in str(e):
|
||||
print("\nNo physical camera detected.")
|
||||
elif isinstance(e, ValueError) and "microphone_index" in str(e):
|
||||
print("\nNo physical microphone detected.")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
@@ -26,16 +26,22 @@ from lerobot.datasets.compute_stats import (
|
||||
compute_episode_stats,
|
||||
estimate_num_samples,
|
||||
get_feature_stats,
|
||||
sample_audio_from_data,
|
||||
sample_audio_from_path,
|
||||
sample_images,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_AUDIO, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def mock_load_image_as_numpy(path, dtype, channel_first):
|
||||
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
|
||||
|
||||
|
||||
def mock_load_audio(path):
    """Stand-in audio loader: ignores `path` and returns 16000 samples x 2 channels of float32 ones."""
    num_samples, num_channels = 16000, 2
    return np.ones((num_samples, num_channels), dtype=np.float32)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_array():
|
||||
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
@@ -73,6 +79,25 @@ def test_sample_images(mock_load):
|
||||
assert len(images) == estimate_num_samples(100)
|
||||
|
||||
|
||||
@patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio)
def test_sample_audio_from_path(mock_load):
    """Sampling audio from a (mocked) file keeps channels and dtype and returns the estimated sample count."""
    sampled = sample_audio_from_path("audio.wav")

    assert isinstance(sampled, np.ndarray)
    # The mocked loader returns 16000 samples over 2 float32 channels
    assert sampled.shape[1] == 2
    assert sampled.dtype == np.float32
    assert len(sampled) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_sample_audio_from_data():
    """Sampling audio from an in-memory array keeps channels and dtype and returns the estimated sample count."""
    num_samples, num_channels = 16000, 2
    sampled = sample_audio_from_data(np.ones((num_samples, num_channels), dtype=np.float32))

    assert isinstance(sampled, np.ndarray)
    assert sampled.shape[1] == 2
    assert sampled.dtype == np.float32
    assert len(sampled) == estimate_num_samples(16000)
|
||||
|
||||
|
||||
def test_get_feature_stats_images():
|
||||
data = np.random.rand(100, 3, 32, 32)
|
||||
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
|
||||
@@ -81,6 +106,14 @@ def test_get_feature_stats_images():
|
||||
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
|
||||
|
||||
|
||||
def test_get_feature_stats_audio():
    """Statistics over the samples axis of an audio array expose every key with consistent shapes."""
    data = np.random.uniform(-1, 1, (16000, 2))
    stats = get_feature_stats(data, axis=0, keepdims=True)

    for stat_name in ("min", "max", "mean", "std", "count"):
        assert stat_name in stats
    np.testing.assert_equal(stats["count"], np.array([16000]))
    # min, max, mean and std must all share one shape
    reference_shape = stats["min"].shape
    assert all(stats[stat_name].shape == reference_shape for stat_name in ("max", "mean", "std"))
|
||||
|
||||
|
||||
def test_get_feature_stats_axis_0_keepdims(sample_array):
|
||||
expected = {
|
||||
"min": np.array([[1, 2, 3]]),
|
||||
@@ -145,20 +178,27 @@ def test_get_feature_stats_single_value():
|
||||
def test_compute_episode_stats():
|
||||
episode_data = {
|
||||
OBS_IMAGE: [f"image_{i}.jpg" for i in range(100)],
|
||||
OBS_AUDIO: "audio.wav",
|
||||
OBS_STATE: np.random.rand(100, 10),
|
||||
}
|
||||
features = {
|
||||
OBS_IMAGE: {"dtype": "image"},
|
||||
OBS_AUDIO: {"dtype": "audio"},
|
||||
OBS_STATE: {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
with (
|
||||
patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy),
|
||||
patch("lerobot.datasets.compute_stats.load_audio_from_path", side_effect=mock_load_audio),
|
||||
):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert OBS_IMAGE in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == 100
|
||||
assert stats[OBS_STATE]["count"].item() == 100
|
||||
assert OBS_IMAGE in stats and OBS_AUDIO in stats and OBS_STATE in stats
|
||||
assert stats[OBS_IMAGE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_AUDIO]["count"].item() == estimate_num_samples(16000)
|
||||
assert stats[OBS_STATE]["count"].item() == estimate_num_samples(100)
|
||||
assert stats[OBS_IMAGE]["mean"].shape == (3, 1, 1)
|
||||
assert stats[OBS_AUDIO]["mean"].shape == (1, 2)
|
||||
|
||||
|
||||
def test_assert_type_and_shape_valid():
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from soundfile import write
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
@@ -37,6 +38,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
_encode_video_worker,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_CHUNK_DURATION,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
@@ -49,7 +51,13 @@ from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_SAMPLE_RATE,
|
||||
DUMMY_AUDIO_CHANNELS,
|
||||
DUMMY_CHW,
|
||||
DUMMY_HWC,
|
||||
DUMMY_REPO_ID,
|
||||
)
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -70,6 +78,36 @@ def image_dataset(tmp_path, empty_lerobot_dataset_factory):
|
||||
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
|
||||
@pytest.fixture
def audio_dataset_le_kiwi(tmp_path, empty_lerobot_dataset_factory):
    """Empty dataset with a single "audio" feature, created with the "lekiwi" robot type."""
    audio_feature = {
        "dtype": "audio",
        "shape": (1, DUMMY_AUDIO_CHANNELS),
        "names": ["channels"],
        "info": {"sample_rate": DEFAULT_SAMPLE_RATE},
    }
    return empty_lerobot_dataset_factory(
        root=tmp_path / "test",
        features={"audio": audio_feature},
        robot_type="lekiwi",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def audio_dataset(tmp_path, empty_lerobot_dataset_factory):
    """Empty dataset with a single "audio" feature, created with the factory's default robot type."""
    audio_feature = {
        "dtype": "audio",
        "shape": (1, DUMMY_AUDIO_CHANNELS),
        "names": ["channels"],
        "info": {"sample_rate": DEFAULT_SAMPLE_RATE},
    }
    return empty_lerobot_dataset_factory(root=tmp_path / "test", features={"audio": audio_feature})
|
||||
|
||||
|
||||
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
@@ -411,6 +449,78 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_audio_array(audio_dataset_le_kiwi):
    """An audio chunk added as an array is read back from the dataset as (channels, samples)."""
    dataset = audio_dataset_le_kiwi
    chunk_num_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)

    frame = {"audio": np.random.rand(chunk_num_samples, DUMMY_AUDIO_CHANNELS)}
    dataset.add_frame(frame, task="Dummy task")
    dataset.save_episode()

    expected_shape = torch.Size((DUMMY_AUDIO_CHANNELS, chunk_num_samples))
    assert dataset[0]["audio"].shape == expected_shape
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_shape(audio_dataset_le_kiwi):
    """Adding a 3-dimensional audio array is rejected with a ValueError."""
    dataset = audio_dataset_le_kiwi
    chunk_num_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)

    with pytest.raises(ValueError):
        # Extra trailing dimension: audio frames are expected to be (samples, channels)
        dataset.add_frame(
            {"audio": np.random.rand(chunk_num_samples, DUMMY_AUDIO_CHANNELS, 99)},
            task="Dummy task",
        )
|
||||
|
||||
|
||||
def test_add_frame_audio_array_wrong_channels_number(audio_dataset_le_kiwi):
    """Adding an audio array whose channel count differs from the declared feature is rejected."""
    dataset = audio_dataset_le_kiwi
    chunk_num_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)

    with pytest.raises(ValueError):
        # 99 channels instead of DUMMY_AUDIO_CHANNELS
        dataset.add_frame({"audio": np.random.rand(chunk_num_samples, 99)}, task="Dummy task")
|
||||
|
||||
|
||||
def test_add_frame_audio_file(audio_dataset):
    """An episode whose audio comes from a raw audio file is read back as (channels, samples)."""
    dataset = audio_dataset
    chunk_num_samples = int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE)

    dataset.add_frame(
        {"audio": np.random.rand(chunk_num_samples, DUMMY_AUDIO_CHANNELS)},
        task="Dummy task",
    )

    # Create the audio file that should be created in the background by the Microphone class
    for audio_key in dataset.meta.audio_keys:
        raw_audio_path = dataset._get_raw_audio_file_path(0, audio_key)
        raw_audio_path.parent.mkdir(parents=True, exist_ok=True)
        raw_audio = np.random.rand(chunk_num_samples, DUMMY_AUDIO_CHANNELS)
        write(raw_audio_path, raw_audio, DEFAULT_SAMPLE_RATE)

    dataset.save_episode()

    expected_shape = torch.Size((DUMMY_AUDIO_CHANNELS, chunk_num_samples))
    assert dataset[0]["audio"].shape == expected_shape
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
@@ -450,6 +560,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
audio_keys = dataset.meta.audio_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -492,6 +603,11 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# test c,h,w
|
||||
assert item[key].shape[0] == 3, f"{key}"
|
||||
|
||||
for key in audio_keys:
|
||||
assert item[key].dtype == torch.float32, f"{key}"
|
||||
assert item[key].max() <= 1.0, f"{key}"
|
||||
assert item[key].min() >= -1.0, f"{key}"
|
||||
|
||||
if delta_timestamps is not None:
|
||||
# test missing keys in delta_timestamps
|
||||
for key in delta_timestamps:
|
||||
|
||||
Vendored
+13
@@ -40,5 +40,18 @@ DUMMY_VIDEO_INFO = {
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
# Audio features of the dummy microphones: one independent feature dict per microphone name
DUMMY_MICROPHONE_FEATURES = {
    microphone_name: {"dtype": "audio", "shape": (1, 2), "names": ["channels"], "info": None}
    for microphone_name in ("laptop", "phone")
}
# Sample rate (Hz) and channel count shared by the audio fixtures
DEFAULT_SAMPLE_RATE = 48000
DUMMY_AUDIO_CHANNELS = 2
# Audio-related metadata entries for the dummy fixtures
DUMMY_AUDIO_INFO = {
    "has_audio": True,
    "audio.sample_rate": DEFAULT_SAMPLE_RATE,
    "audio.codec": "aac",
    "audio.channels": DUMMY_AUDIO_CHANNELS,
    "audio.channel_layout": "stereo",
}
|
||||
DUMMY_CHW = (3, 96, 128)
|
||||
DUMMY_HWC = (96, 128, 3)
|
||||
|
||||
Vendored
+18
-1
@@ -28,6 +28,7 @@ from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -42,6 +43,7 @@ from lerobot.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MICROPHONE_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
@@ -130,6 +132,7 @@ def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
@@ -141,6 +144,7 @@ def features_factory():
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**audio_features,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
@@ -157,16 +161,19 @@ def info_factory(features_factory):
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_audio: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path: str = DEFAULT_DATA_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
audio_path: str = DEFAULT_AUDIO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
audio_features: dict = DUMMY_MICROPHONE_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
features = features_factory(motor_features, camera_features, audio_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
@@ -174,6 +181,7 @@ def info_factory(features_factory):
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_audio": total_audio,
|
||||
"chunks_size": chunks_size,
|
||||
"data_files_size_in_mb": data_files_size_in_mb,
|
||||
"video_files_size_in_mb": video_files_size_in_mb,
|
||||
@@ -181,6 +189,7 @@ def info_factory(features_factory):
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"audio_path": audio_path,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
@@ -204,6 +213,14 @@ def stats_factory():
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
elif dtype == "audio":
|
||||
stats[key] = {
|
||||
"mean": np.full((shape[0],), 0.0, dtype=np.float32).tolist(),
|
||||
"max": np.full((shape[0],), 1, dtype=np.float32).tolist(),
|
||||
"min": np.full((shape[0],), -1, dtype=np.float32).tolist(),
|
||||
"std": np.full((shape[0],), 0.5, dtype=np.float32).tolist(),
|
||||
"count": [10],
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
|
||||
@@ -0,0 +1,532 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundfile import read
|
||||
|
||||
from lerobot.microphones.portaudio.configuration_portaudio import PortAudioMicrophoneConfig
|
||||
from lerobot.microphones.portaudio.interface_sounddevice_sdk import (
|
||||
FakeSounddeviceSDKAdapter,
|
||||
SounddeviceSDKAdapter,
|
||||
)
|
||||
from lerobot.microphones.portaudio.microphone_portaudio import PortAudioMicrophone
|
||||
from lerobot.microphones.utils import async_microphones_start_recording, async_microphones_stop_recording
|
||||
from lerobot.utils.errors import (
|
||||
DeviceAlreadyConnectedError,
|
||||
DeviceAlreadyRecordingError,
|
||||
DeviceNotConnectedError,
|
||||
DeviceNotRecordingError,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
# Dotted path of the module that defines PortAudioMicrophone.
MODULE_PATH = "lerobot.microphones.portaudio.microphone_portaudio"
# How long the recording tests let the microphone run before checking results.
RECORDING_DURATION = 1.0

# Opt-in switch: when the environment variable equals "true" (any casing) the
# tests use the real sounddevice adapter instead of the fake one.
LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS = (
    os.environ.get("LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS", "False").lower() == "true"
)
|
||||
|
||||
|
||||
@pytest.fixture
def test_sdk():
    """Return the SDK adapter the tests should talk to.

    The real adapter is used when LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS
    is true; otherwise the fake adapter is returned.
    """
    adapter_cls = (
        SounddeviceSDKAdapter if LEROBOT_USE_REAL_PORTAUDIO_MICROPHONE_TESTS else FakeSounddeviceSDKAdapter
    )
    return adapter_cls()
|
||||
|
||||
|
||||
# Configuration Tests
|
||||
|
||||
|
||||
def test_config_creation():
    """A fully specified config keeps the values it was constructed with."""
    cfg = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000, channels=[1, 2])

    assert cfg.microphone_index == 0
    assert cfg.sample_rate == 48000
    assert cfg.channels == [1, 2]
|
||||
|
||||
|
||||
def test_config_creation_missing_microphone_index():
    """Omitting the microphone index makes the constructor raise TypeError."""
    kwargs_without_index = {"sample_rate": 48000, "channels": [1, 2]}
    with pytest.raises(TypeError):
        PortAudioMicrophoneConfig(**kwargs_without_index)
|
||||
|
||||
|
||||
def test_config_creation_missing_sample_rate():
    """A config built without a sample rate reports it as None."""
    cfg = PortAudioMicrophoneConfig(microphone_index=0, channels=[1, 2])
    assert cfg.sample_rate is None
|
||||
|
||||
|
||||
def test_config_creation_missing_channels():
    """A config built without channels reports them as None."""
    cfg = PortAudioMicrophoneConfig(microphone_index=0, sample_rate=48000)
    assert cfg.channels is None
|
||||
|
||||
|
||||
@pytest.fixture
def default_config(test_sdk):
    """Build a config from the SDK's reported input device.

    The index and default sample rate are copied from the device info, and the
    channel list is 1..max_input_channels.
    """
    input_info = test_sdk.query_devices(kind="input")
    all_channels = np.arange(input_info["max_input_channels"]) + 1
    return PortAudioMicrophoneConfig(
        microphone_index=input_info["index"],
        sample_rate=input_info["default_samplerate"],
        channels=all_channels,
    )
|
||||
|
||||
|
||||
# Microphone Tests
|
||||
|
||||
|
||||
def test_find_microphones(test_sdk):
    """Every entry returned by find_microphones has correctly typed fields."""
    found = PortAudioMicrophone.find_microphones(sounddevice_sdk=test_sdk)

    expected_types = (
        ("index", int),
        ("name", str),
        ("sample_rate", int),
        ("channels", np.ndarray),
    )
    for entry in found:
        for field_name, field_type in expected_types:
            assert isinstance(entry[field_name], field_type)
        assert len(entry["channels"]) > 0
|
||||
|
||||
|
||||
def test_init_defaults(default_config, test_sdk):
    """A freshly constructed microphone mirrors its config and is idle."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)

    input_info = test_sdk.query_devices(kind="input")
    expected_channels = np.arange(input_info["max_input_channels"]) + 1

    assert mic is not None
    assert mic.microphone_index == input_info["index"]
    assert mic.sample_rate == input_info["default_samplerate"]
    np.testing.assert_array_equal(mic.channels, expected_channels)
    assert not mic.is_connected
    assert not mic.is_recording
|
||||
|
||||
|
||||
def test_connect_success(default_config, test_sdk):
    """After connect() the microphone is connected but neither recording nor writing."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()

    assert mic.is_connected
    assert not mic.is_recording
    assert not mic.is_writing
|
||||
|
||||
|
||||
def test_connect_empty_config(default_config, test_sdk):
    """With sample_rate and channels unset, connect() falls back to the device defaults."""
    blank_config = deepcopy(default_config)
    blank_config.sample_rate = None
    blank_config.channels = None
    mic = PortAudioMicrophone(blank_config, sounddevice_sdk=test_sdk)
    mic.connect()

    input_info = test_sdk.query_devices(kind="input")
    expected_channels = np.arange(input_info["max_input_channels"]) + 1
    assert mic.sample_rate == input_info["default_samplerate"]
    np.testing.assert_array_equal(mic.channels, expected_channels)
|
||||
|
||||
|
||||
def test_connect_already_connected(default_config, test_sdk):
    """A second connect() on a connected microphone raises DeviceAlreadyConnectedError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()

    with pytest.raises(DeviceAlreadyConnectedError):
        mic.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_device(test_sdk):
    """Connecting to the device reported for kind="output" raises RuntimeError."""
    output_info = test_sdk.query_devices(kind="output")
    channels = np.arange(output_info["max_input_channels"]) + 1
    output_config = PortAudioMicrophoneConfig(
        microphone_index=output_info["index"],
        sample_rate=output_info["default_samplerate"],
        channels=channels,
    )
    mic = PortAudioMicrophone(output_config, sounddevice_sdk=test_sdk)

    with pytest.raises(RuntimeError):
        mic.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_index(default_config, test_sdk):
    """A microphone index of -1 makes connect() raise RuntimeError."""
    bad_config = deepcopy(default_config)
    bad_config.microphone_index = -1
    mic = PortAudioMicrophone(bad_config, sounddevice_sdk=test_sdk)

    with pytest.raises(RuntimeError):
        mic.connect()
|
||||
|
||||
|
||||
def test_connect_invalid_sample_rate(default_config, test_sdk):
    """A sample rate of -1 makes connect() raise RuntimeError."""
    bad_config = deepcopy(default_config)
    bad_config.sample_rate = -1
    mic = PortAudioMicrophone(bad_config, sounddevice_sdk=test_sdk)

    with pytest.raises(RuntimeError):
        mic.connect()
|
||||
|
||||
|
||||
def test_connect_float_sample_rate(default_config, test_sdk):
    """A fractional sample rate is accepted and exposed as an int after connect()."""
    float_config = deepcopy(default_config)
    # Half a sample below an integer rate, so the value is genuinely fractional.
    float_config.sample_rate = int(float_config.sample_rate) - 0.5
    mic = PortAudioMicrophone(float_config, sounddevice_sdk=test_sdk)
    mic.connect()

    assert isinstance(mic.sample_rate, int)
    assert mic.sample_rate == int(float_config.sample_rate)
|
||||
|
||||
|
||||
def test_connect_lower_sample_rate(default_config, test_sdk):
    """Connecting with a sample rate of 1000 keeps that rate on the microphone."""
    slow_config = deepcopy(default_config)
    slow_config.sample_rate = 1000  # Lowest possible sample rate
    mic = PortAudioMicrophone(slow_config, sounddevice_sdk=test_sdk)

    mic.connect()
    assert mic.sample_rate == 1000
|
||||
|
||||
|
||||
def test_connect_invalid_channels(default_config, test_sdk):
    """Appending channel -1 to the channel list makes connect() raise RuntimeError."""
    bad_config = deepcopy(default_config)
    bad_config.channels = np.append(default_config.channels, -1)
    mic = PortAudioMicrophone(bad_config, sounddevice_sdk=test_sdk)

    with pytest.raises(RuntimeError):
        mic.connect()
|
||||
|
||||
|
||||
def test_disconnect_success(default_config, test_sdk):
    """connect() followed by disconnect() leaves every state flag cleared."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.disconnect()

    assert not mic.is_connected
    assert not mic.is_recording
    assert not mic.is_writing
|
||||
|
||||
|
||||
def test_disconnect_not_connected(default_config, test_sdk):
    """disconnect() on a never-connected microphone raises DeviceNotConnectedError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)

    with pytest.raises(DeviceNotConnectedError):
        mic.disconnect()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_recording_success(default_config, test_sdk, multiprocessing):
    """Starting a recording without an output file sets is_recording but not is_writing."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(multiprocessing=multiprocessing)

    assert mic.is_recording
    assert mic.is_connected
    assert not mic.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_recording_not_connected(default_config, test_sdk, multiprocessing):
    """start_recording() before connect() raises DeviceNotConnectedError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)

    with pytest.raises(DeviceNotConnectedError):
        mic.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_recording_already_recording(default_config, test_sdk, multiprocessing):
    """A second start_recording() while recording raises DeviceAlreadyRecordingError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(multiprocessing=multiprocessing)

    with pytest.raises(DeviceAlreadyRecordingError):
        mic.start_recording(multiprocessing=multiprocessing)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
    """Starting a recording with an output file sets is_writing and creates the file."""
    wav_path = tmp_path / "test.wav"
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing)

    assert mic.is_recording
    assert mic.is_connected
    assert mic.is_writing
    assert wav_path.exists()

    wav_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_start_writing_file_already_exists_no_overwrite(tmp_path, default_config, test_sdk, multiprocessing):
    """With overwrite=False, recording to an existing file raises FileExistsError."""
    wav_path = tmp_path / "test.wav"
    wav_path.touch()  # Pre-create the target so the no-overwrite check has something to hit
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()

    with pytest.raises(FileExistsError):
        mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing, overwrite=False)

    wav_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_stop_recording_success(default_config, test_sdk, multiprocessing):
    """stop_recording() clears is_recording while the microphone stays connected."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(multiprocessing=multiprocessing)
    precise_sleep(RECORDING_DURATION)
    mic.stop_recording()

    assert not mic.is_recording
    assert mic.is_connected
    assert not mic.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_stop_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
    """Stopping a file-backed recording clears the flags and leaves the file on disk."""
    wav_path = tmp_path / "test.wav"
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing)
    precise_sleep(RECORDING_DURATION)
    mic.stop_recording()

    assert not mic.is_recording
    assert mic.is_connected
    assert not mic.is_writing
    assert wav_path.exists()

    wav_path.unlink()
|
||||
|
||||
|
||||
def test_stop_recording_not_connected(default_config, test_sdk):
    """stop_recording() before connect() raises DeviceNotConnectedError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)

    with pytest.raises(DeviceNotConnectedError):
        mic.stop_recording()
|
||||
|
||||
|
||||
def test_stop_recording_not_recording(default_config, test_sdk):
    """stop_recording() on a connected but idle microphone raises DeviceNotRecordingError."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()

    with pytest.raises(DeviceNotRecordingError):
        mic.stop_recording()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_disconnect_while_recording(default_config, test_sdk, multiprocessing):
    """disconnect() during an active recording clears every state flag."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(multiprocessing=multiprocessing)
    precise_sleep(RECORDING_DURATION)
    mic.disconnect()

    assert not mic.is_connected
    assert not mic.is_recording
    assert not mic.is_writing
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_disconnect_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
    """disconnect() during a file-backed recording clears the flags and keeps the file."""
    wav_path = tmp_path / "test.wav"
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing)
    precise_sleep(RECORDING_DURATION)
    mic.disconnect()

    assert not mic.is_connected
    assert not mic.is_recording
    assert not mic.is_writing
    assert wav_path.exists()

    wav_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_read_success(default_config, test_sdk, multiprocessing):
    """read() returns one column per channel and roughly RECORDING_DURATION worth of samples."""
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(multiprocessing=multiprocessing)

    precise_sleep(RECORDING_DURATION)

    samples = mic.read()

    input_info = test_sdk.query_devices(kind="input")
    expected_rows = RECORDING_DURATION * default_config.sample_rate
    # Allowed deviation: twice the device's low input latency, expressed in samples.
    tolerance = 2 * default_config.sample_rate * input_info["default_low_input_latency"]

    assert samples is not None
    assert samples.shape[1] == len(default_config.channels)
    assert abs(samples.shape[0] - expected_rows) <= tolerance
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_writing_success(tmp_path, default_config, test_sdk, multiprocessing):
    """The written file has the configured rate, channel count and expected length."""
    wav_path = tmp_path / "test.wav"
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing)

    precise_sleep(RECORDING_DURATION)

    mic.stop_recording()

    file_samples, file_rate = read(wav_path)

    input_info = test_sdk.query_devices(kind="input")
    expected_rows = RECORDING_DURATION * default_config.sample_rate
    # Allowed deviation: twice the device's low input latency, expressed in samples.
    tolerance = 2 * default_config.sample_rate * input_info["default_low_input_latency"]

    assert file_rate == default_config.sample_rate
    assert file_samples.shape[1] == len(default_config.channels)
    assert abs(file_samples.shape[0] - expected_rows) <= tolerance

    wav_path.unlink()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multiprocessing", [True, False])
def test_read_while_writing(tmp_path, default_config, test_sdk, multiprocessing):
    """Both read() and the written file hold roughly RECORDING_DURATION worth of samples."""
    wav_path = tmp_path / "test.wav"
    mic = PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk)
    mic.connect()
    mic.start_recording(output_file=wav_path, multiprocessing=multiprocessing)

    precise_sleep(RECORDING_DURATION)

    in_memory = mic.read()
    mic.stop_recording()

    on_disk, _ = read(wav_path)

    input_info = test_sdk.query_devices(kind="input")
    expected_rows = RECORDING_DURATION * default_config.sample_rate
    # Allowed deviation: twice the device's low input latency, expressed in samples.
    tolerance = 2 * default_config.sample_rate * input_info["default_low_input_latency"]

    assert abs(on_disk.shape[0] - expected_rows) <= tolerance
    assert abs(in_memory.shape[0] - expected_rows) <= tolerance

    wav_path.unlink()
|
||||
|
||||
|
||||
def test_async_start_recording(default_config, test_sdk):
    """async_microphones_start_recording puts every microphone into the recording state."""
    microphones = {
        "microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
        "microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
    }
    for mic in microphones.values():
        mic.connect()

    async_microphones_start_recording(microphones)

    for mic in microphones.values():
        assert mic.is_recording
        assert mic.is_connected
        assert not mic.is_writing
|
||||
|
||||
|
||||
def test_async_start_writing(tmp_path, default_config, test_sdk):
    """Starting with output files puts every microphone into the writing state and creates the files."""
    microphones = {
        "microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
        "microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
    }
    for mic in microphones.values():
        mic.connect()

    first_wav = tmp_path / "test_1.wav"
    second_wav = tmp_path / "test_2.wav"
    async_microphones_start_recording(microphones, output_files=[first_wav, second_wav])

    for mic in microphones.values():
        assert mic.is_recording
        assert mic.is_connected
        assert mic.is_writing
    assert first_wav.exists()
    assert second_wav.exists()

    first_wav.unlink()
    second_wav.unlink()
|
||||
|
||||
|
||||
def test_async_stop_recording(default_config, test_sdk):
    """async_microphones_stop_recording returns every microphone to the connected, idle state."""
    microphones = {
        "microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
        "microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
    }
    for mic in microphones.values():
        mic.connect()

    async_microphones_start_recording(microphones)
    async_microphones_stop_recording(microphones)

    for mic in microphones.values():
        assert not mic.is_recording
        assert mic.is_connected
        assert not mic.is_writing
|
||||
|
||||
|
||||
def test_async_stop_writing(tmp_path, default_config, test_sdk):
    """Stopping file-backed recordings clears the flags and leaves both files on disk."""
    microphones = {
        "microphone_1": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
        "microphone_2": PortAudioMicrophone(default_config, sounddevice_sdk=test_sdk),
    }
    for mic in microphones.values():
        mic.connect()

    first_wav = tmp_path / "test_1.wav"
    second_wav = tmp_path / "test_2.wav"
    async_microphones_start_recording(microphones, output_files=[first_wav, second_wav])
    async_microphones_stop_recording(microphones)

    for mic in microphones.values():
        assert not mic.is_recording
        assert mic.is_connected
        assert not mic.is_writing
    assert first_wav.exists()
    assert second_wav.exists()

    first_wav.unlink()
    second_wav.unlink()
|
||||
@@ -0,0 +1,508 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
from multiprocessing import Event, Process, Queue
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.utils.shared_array import SharedArray
|
||||
|
||||
|
||||
def writer_process(shared_array, data_queue, stop_event, barrier, process_id):
    """Write up to ten (5, 2) float32 blocks into the shared array.

    Each block is filled with ``process_id * 100 + write_count`` so the origin
    of a write can be identified. After every successful write a
    ``"writer_<id>_wrote_<n>"`` message is put on ``data_queue``.

    Args:
        shared_array: Object providing ``get_local_array()`` and ``write()``.
        data_queue: Queue-like object receiving one message per write.
        stop_event: Event-like object; the loop ends once it is set.
        barrier: Barrier-like object waited on once before writing starts.
        process_id: Integer used in the data values and in the messages.
    """
    local_array = shared_array.get_local_array()

    # Wait for all processes to be ready
    barrier.wait()

    write_count = 0
    while not stop_event.is_set() and write_count < 10:
        # Generate unique data for this process and write iteration
        data = np.full((5, 2), process_id * 100 + write_count, dtype=np.float32)

        try:
            shared_array.write(local_array, data)
        except (IndexError, ValueError):
            # Array is full, stop writing. test_array_overflow expects a
            # ValueError on overflow, so catching IndexError alone would let
            # the worker crash instead of stopping.
            break

        data_queue.put(f"writer_{process_id}_wrote_{write_count}")
        write_count += 1
        time.sleep(0.01)  # Small delay to allow race conditions
|
||||
|
||||
|
||||
def reader_process(shared_array, data_queue, stop_event, barrier, process_id):
    """Perform up to five flushing reads, reporting the length of each on the queue.

    The function waits on ``barrier`` once, then sleeps 20 ms before each read
    and puts ``"reader_<id>_read_<len>_items"`` on ``data_queue`` afterwards.
    The loop ends early once ``stop_event`` is set.
    """
    view = shared_array.get_local_array()

    # Start together with the other workers
    barrier.wait()

    reads_done = 0
    while not stop_event.is_set() and reads_done < 5:
        time.sleep(0.02)  # Give writers a chance to add data

        chunk = shared_array.read(view, flush=True)
        item_count = len(chunk)
        data_queue.put(f"reader_{process_id}_read_{item_count}_items")
        reads_done += 1
|
||||
|
||||
|
||||
def stress_writer_process(shared_array, data_queue, stop_event, barrier, process_id):
    """Write up to fifty single-row blocks back to back, then report the count.

    Each row is ``[process_id, write_count]`` as float32. When the loop ends,
    one ``"stress_writer_<id>_completed_<n>"`` message is put on ``data_queue``.

    Args:
        shared_array: Object providing ``get_local_array()`` and ``write()``.
        data_queue: Queue-like object receiving the final summary message.
        stop_event: Event-like object; the loop ends once it is set.
        barrier: Barrier-like object waited on once before writing starts.
        process_id: Integer stored in the first column and used in the message.
    """
    local_array = shared_array.get_local_array()

    barrier.wait()

    write_count = 0
    while not stop_event.is_set() and write_count < 50:
        # Write single row at a time for more frequent operations
        data = np.array([[process_id, write_count]], dtype=np.float32)

        try:
            shared_array.write(local_array, data)
        except (IndexError, ValueError):
            # Array is full. test_array_overflow expects a ValueError on
            # overflow, so both exception types end the loop cleanly.
            break

        write_count += 1
        # No sleep - stress test

    data_queue.put(f"stress_writer_{process_id}_completed_{write_count}")
|
||||
|
||||
|
||||
# Basic functionality tests
|
||||
|
||||
|
||||
def test_shared_array_creation():
    """A new SharedArray exposes its shape and dtype and starts with read_index at 0."""
    expected_shape = (100, 4)
    expected_dtype = np.float32

    array = SharedArray(shape=expected_shape, dtype=expected_dtype)

    assert array.shape == expected_shape
    assert array.dtype == expected_dtype
    assert array.read_index.value == 0

    # Clean up
    array.delete()
|
||||
|
||||
|
||||
def test_local_array_access():
    """get_local_array() returns a numpy array with the shared shape and dtype, repeatedly."""
    expected_shape = (50, 2)
    array = SharedArray(shape=expected_shape, dtype=np.float32)

    first_view = array.get_local_array()

    assert first_view.shape == expected_shape
    assert first_view.dtype == np.float32
    assert isinstance(first_view, np.ndarray)

    # A second call must also yield an array of the same shape
    second_view = array.get_local_array()
    assert second_view.shape == expected_shape

    array.delete()
|
||||
|
||||
|
||||
def test_write_and_read_single_process():
    """Two writes advance read_index; reads return all rows and flush resets the index."""
    array = SharedArray(shape=(20, 3), dtype=np.float32)
    view = array.get_local_array()

    # First write: two rows
    first_batch = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    array.write(view, first_batch)

    assert array.read_index.value == 2

    # Second write: one more row
    second_batch = np.array([[7, 8, 9]], dtype=np.float32)
    array.write(view, second_batch)

    assert array.read_index.value == 3

    # A non-flushing read returns everything written so far
    everything = np.vstack([first_batch, second_batch])
    np.testing.assert_array_equal(array.read(view, flush=False), everything)

    # A flushing read returns the same rows and rewinds the index
    flushed = array.read(view, flush=True)
    np.testing.assert_array_equal(flushed, everything)
    assert array.read_index.value == 0

    array.delete()
|
||||
|
||||
|
||||
def test_array_overflow():
    """Writing more rows than the array can hold raises ValueError."""
    shape = (5, 2)  # Small array
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    try:
        local_array = shared_array.get_local_array()

        # Fill the array
        data = np.ones((5, 2), dtype=np.float32)
        shared_array.write(local_array, data)

        # Build the payload outside the raises block so that only the write
        # itself is allowed to raise.
        extra_data = np.ones((2, 2), dtype=np.float32)

        # Try to write more data - should raise ValueError
        with pytest.raises(ValueError):
            shared_array.write(local_array, extra_data)
    finally:
        # Call delete() even when an assertion above fails
        shared_array.delete()
|
||||
|
||||
|
||||
def test_reset_functionality():
    """reset() rewinds read_index to 0 so a following read returns nothing."""
    array = SharedArray(shape=(10, 2), dtype=np.float32)
    view = array.get_local_array()

    # Put three rows in
    array.write(view, np.ones((3, 2), dtype=np.float32))
    assert array.read_index.value == 3

    # Rewind
    array.reset()
    assert array.read_index.value == 0

    # Nothing should be readable after the reset
    leftover = array.read(view, flush=False)
    assert len(leftover) == 0

    array.delete()
|
||||
|
||||
|
||||
# Multi-process tests
|
||||
|
||||
|
||||
def test_single_writer_single_reader():
    """Run one writer and one reader process against the same SharedArray.

    Both processes must exit after the stop event is set, and each must have
    put at least one message on the queue.
    """
    shared_array = SharedArray(shape=(100, 2), dtype=np.float32)

    data_queue = Queue()
    stop_event = Event()
    barrier = multiprocessing.Barrier(2)  # Writer + reader

    # Both workers take the same arguments, including process id 1
    worker_args = (shared_array, data_queue, stop_event, barrier, 1)
    writer = Process(target=writer_process, args=worker_args)
    reader = Process(target=reader_process, args=worker_args)

    writer.start()
    reader.start()

    # Give the workers some time, then ask them to stop
    time.sleep(0.5)
    stop_event.set()

    writer.join(timeout=2.0)
    reader.join(timeout=2.0)

    # Neither process may still be running
    assert not writer.is_alive()
    assert not reader.is_alive()

    # Drain the queue
    messages = []
    while not data_queue.empty():
        messages.append(data_queue.get())

    # At least one message from each side
    assert any(msg.startswith("writer_") for msg in messages)
    assert any(msg.startswith("reader_") for msg in messages)

    shared_array.delete()
|
||||
|
||||
|
||||
def test_multiple_writers_single_reader():
    """Run three writer processes and one reader process against the same SharedArray.

    Every process must exit after the stop event is set; the queue must hold at
    least ``num_writers`` writer messages and at least one reader message.
    """
    shared_array = SharedArray(shape=(200, 2), dtype=np.float32)

    data_queue = Queue()
    stop_event = Event()
    num_writers = 3
    barrier = multiprocessing.Barrier(num_writers + 1)  # Writers + reader

    workers = []

    # Writers get ids 1..num_writers and are started as they are created
    for writer_id in range(1, num_writers + 1):
        writer = Process(
            target=writer_process, args=(shared_array, data_queue, stop_event, barrier, writer_id)
        )
        workers.append(writer)
        writer.start()

    # One reader, started last
    reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))
    workers.append(reader)
    reader.start()

    # Give the workers some time, then ask them to stop
    time.sleep(1.0)
    stop_event.set()

    # Every worker must have exited within its join timeout
    for worker in workers:
        worker.join(timeout=3.0)
        assert not worker.is_alive()

    # Drain the queue
    messages = []
    while not data_queue.empty():
        messages.append(data_queue.get())

    writer_message_count = sum(1 for msg in messages if msg.startswith("writer_"))
    reader_message_count = sum(1 for msg in messages if msg.startswith("reader_"))

    # Should have messages from all writers
    assert writer_message_count >= num_writers
    assert reader_message_count > 0

    shared_array.delete()
|
||||
|
||||
|
||||
def test_data_integrity_with_concurrent_access():
    """Check that the shared array stays consistent with two writers and one reader.

    Two ``writer_process`` workers and one ``reader_process`` worker run for
    about one second. Afterwards the test checks that every worker exited and
    reported through the queue, that the remaining content of the array does
    not exceed its capacity, and that any remaining values fall in the ranges
    the writers are expected to produce.
    """
    shape = (500, 2)  # Use standard 2-column format
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []

    try:
        data_queue = Queue()
        stop_event = Event()
        barrier = multiprocessing.Barrier(3)  # 2 writers + 1 reader

        # Two writer processes
        writer1 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 1))
        writer2 = Process(target=writer_process, args=(shared_array, data_queue, stop_event, barrier, 2))

        # One reader process
        reader = Process(target=reader_process, args=(shared_array, data_queue, stop_event, barrier, 1))

        processes.extend([writer1, writer2, reader])
        writer1.start()
        writer2.start()
        reader.start()

        # Let them run for integrity test duration
        time.sleep(1.0)
        stop_event.set()

        # Wait for completion
        writer1.join(timeout=3.0)
        writer2.join(timeout=3.0)
        reader.join(timeout=3.0)

        # Verify all processes completed successfully
        assert not writer1.is_alive()
        assert not writer2.is_alive()
        assert not reader.is_alive()

        # Verify data integrity by checking messages
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        writer1_messages = [msg for msg in messages if "writer_1_wrote" in msg]
        writer2_messages = [msg for msg in messages if "writer_2_wrote" in msg]
        reader_messages = [msg for msg in messages if "reader_1_read" in msg]

        # Verify both writers wrote data
        assert len(writer1_messages) > 0
        assert len(writer2_messages) > 0
        # Verify reader read data
        assert len(reader_messages) > 0

        # Verify the shared array is in a consistent state
        local_array = shared_array.get_local_array()
        final_data = shared_array.read(local_array, flush=False)

        # The array may legitimately be empty here if the reader flushed
        # everything, so only the upper bound is checked.
        assert len(final_data) <= shape[0]

        # If there's data, verify it contains the expected writer signatures
        if len(final_data) > 0:
            # Data should contain values like 100, 101, 102... (writer 1) or 200, 201, 202... (writer 2)
            unique_values = np.unique(final_data.flatten())
            writer1_values = unique_values[(unique_values >= 100) & (unique_values < 200)]
            writer2_values = unique_values[(unique_values >= 200) & (unique_values < 300)]

            # Should have data from at least one writer
            assert len(writer1_values) > 0 or len(writer2_values) > 0
    finally:
        # Cleanup runs even when an assertion fails, so no worker outlives the
        # test and the shared array is always released.
        for process in processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=3.0)
        shared_array.delete()
|
||||
|
||||
|
||||
def test_stress_test_high_frequency_operations():
    """Stress the shared array with several high-frequency writer processes.

    Four ``stress_writer_process`` workers run for about half a second. Each
    must exit within the join timeout and post a message containing
    ``"completed"``. The array must then hold some data, and no more than its
    capacity.
    """
    shape = (1000, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []

    try:
        data_queue = Queue()
        stop_event = Event()
        num_writers = 4
        barrier = multiprocessing.Barrier(num_writers)

        # Start multiple high-frequency writers
        for i in range(num_writers):
            writer = Process(
                target=stress_writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(writer)
            writer.start()

        # Let them run for stress test duration
        time.sleep(0.5)
        stop_event.set()

        # Wait for completion
        for process in processes:
            process.join(timeout=3.0)
            assert not process.is_alive()

        # Verify all writers completed successfully
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        completed_messages = [msg for msg in messages if "completed" in msg]
        assert len(completed_messages) == num_writers

        # Verify the shared array is in a consistent state
        local_array = shared_array.get_local_array()
        final_data = shared_array.read(local_array, flush=False)

        # Should have some data written
        assert len(final_data) > 0
        # Should not exceed array capacity
        assert len(final_data) <= shape[0]
    finally:
        # Cleanup runs even when an assertion fails, so no worker outlives the
        # test and the shared array is always released.
        for process in processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=3.0)
        shared_array.delete()
|
||||
|
||||
|
||||
def test_concurrent_readers():
    """Run several reader processes alongside writers on one shared array.

    Two ``writer_process`` workers and three ``reader_process`` workers run
    for about one second. All must exit within the join timeout, the queue
    must hold at least one message per worker, and the reader messages must
    carry ``num_readers`` distinct reader IDs, showing that every reader took
    part.
    """
    shape = (200, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)
    processes = []

    try:
        data_queue = Queue()
        stop_event = Event()
        num_readers = 3
        num_writers = 2
        barrier = multiprocessing.Barrier(num_readers + num_writers)

        # Start multiple writer processes to generate data
        for i in range(num_writers):
            writer = Process(
                target=writer_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(writer)
            writer.start()

        # Start multiple reader processes
        for i in range(num_readers):
            reader = Process(
                target=reader_process, args=(shared_array, data_queue, stop_event, barrier, i + 1)
            )
            processes.append(reader)
            reader.start()

        # Let them run to test concurrent access
        time.sleep(1.0)
        stop_event.set()

        # Wait for all processes to complete
        for process in processes:
            process.join(timeout=3.0)
            assert not process.is_alive()

        # Verify all readers and writers completed
        messages = []
        while not data_queue.empty():
            messages.append(data_queue.get())

        reader_messages = [msg for msg in messages if msg.startswith("reader_")]
        writer_messages = [msg for msg in messages if msg.startswith("writer_")]

        # Should have messages from all readers and writers
        assert len(reader_messages) >= num_readers
        assert len(writer_messages) >= num_writers

        # Messages look like "reader_1_read_5_items"; the second "_"-separated
        # field is the reader ID. Every message here starts with "reader_", so
        # that field always exists.
        reader_ids = {msg.split("_")[1] for msg in reader_messages}

        assert len(reader_ids) == num_readers  # All readers should have participated
    finally:
        # Cleanup runs even when an assertion fails, so no worker outlives the
        # test and the shared array is always released.
        for process in processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=3.0)
        shared_array.delete()
|
||||
|
||||
|
||||
def test_edge_case_empty_reads():
    """Check reads from an empty array and from an array that was just flushed.

    The sequence is: read from a fresh array (expect nothing), write three
    rows, read them back with ``flush=True``, then read again (expect nothing).
    """
    shape = (10, 2)
    shared_array = SharedArray(shape=shape, dtype=np.float32)

    try:
        local_array = shared_array.get_local_array()

        # Read from empty array
        empty_data = shared_array.read(local_array, flush=False)
        assert len(empty_data) == 0

        # Write some data
        data = np.ones((3, 2), dtype=np.float32)
        shared_array.write(local_array, data)

        # Read with flush
        read_data = shared_array.read(local_array, flush=True)
        assert len(read_data) == 3

        # Read again after flush - should be empty
        empty_again = shared_array.read(local_array, flush=False)
        assert len(empty_again) == 0
    finally:
        # Release the shared array even when an assertion above fails.
        shared_array.delete()
|
||||
|
||||
|
||||
def test_different_dtypes():
    """Check that SharedArray round-trips data for several numpy dtypes.

    For each dtype a fresh array is created, five rows of ones are written and
    read back with ``flush=True``, and both the local array and the returned
    data must keep the requested dtype.
    """
    dtypes_to_test = [np.float32, np.float64, np.int32, np.int16]

    for dtype in dtypes_to_test:
        shape = (20, 2)
        shared_array = SharedArray(shape=shape, dtype=dtype)

        try:
            local_array = shared_array.get_local_array()

            assert local_array.dtype == dtype

            # Write and read data of this dtype
            data = np.ones((5, 2), dtype=dtype)
            shared_array.write(local_array, data)

            read_data = shared_array.read(local_array, flush=True)
            assert read_data.dtype == dtype
            assert len(read_data) == 5
        finally:
            # Release this iteration's array even when an assertion fails, so
            # a failure for one dtype does not leak it.
            shared_array.delete()
|
||||
+8
-1
@@ -20,7 +20,7 @@ from functools import wraps
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot import available_cameras, available_microphones, available_motors, available_robots
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
@@ -33,6 +33,10 @@ TEST_CAMERA_TYPES = []
|
||||
for camera_type in available_cameras:
|
||||
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
|
||||
|
||||
TEST_MICROPHONE_TYPES = []
|
||||
for microphone_type in available_microphones:
|
||||
TEST_MICROPHONE_TYPES += [(microphone_type, True), (microphone_type, False)]
|
||||
|
||||
TEST_MOTOR_TYPES = []
|
||||
for motor_type in available_motors:
|
||||
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
|
||||
@@ -41,6 +45,9 @@ for motor_type in available_motors:
|
||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
|
||||
|
||||
# Microphone indices used for connecting physical microphones
|
||||
MICROPHONE_INDEX = int(os.environ.get("LEROBOT_TEST_MICROPHONE_INDEX", 0))
|
||||
|
||||
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
||||
DYNAMIXEL_MOTORS = {
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
|
||||
@@ -37,6 +37,14 @@ def mock_rerun(monkeypatch):
|
||||
def __init__(self, value):
|
||||
self.value = float(value)
|
||||
|
||||
@staticmethod
|
||||
def columns(scalars):
|
||||
return DummyScalarsColumn(scalars)
|
||||
|
||||
class DummyScalarsColumn:
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
class DummyImage:
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
@@ -47,12 +55,46 @@ def mock_rerun(monkeypatch):
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
def dummy_send_columns(key, indexes, columns, **kwargs):
|
||||
calls.append((key, columns, kwargs))
|
||||
|
||||
def dummy_time_column(timeline, timestamp):
|
||||
return timestamp
|
||||
|
||||
def dummy_set_time(timeline, timestamp):
|
||||
return None
|
||||
|
||||
class DummyTimeSeriesView:
|
||||
def __call__(self, origin, plot_legend=None):
|
||||
return None
|
||||
|
||||
class DummySpatial2DView:
|
||||
def __call__(self, origin):
|
||||
return None
|
||||
|
||||
class DummyGrid:
|
||||
def __call__(self, *args):
|
||||
return None
|
||||
|
||||
class DummyPlotLegend:
|
||||
def __call__(self, visible=True):
|
||||
return None
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
TimeColumn=dummy_time_column,
|
||||
send_columns=dummy_send_columns,
|
||||
set_time=dummy_set_time,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
blueprint=SimpleNamespace(
|
||||
TimeSeriesView=DummyTimeSeriesView,
|
||||
Spatial2DView=DummySpatial2DView,
|
||||
Grid=DummyGrid,
|
||||
PlotLegend=DummyPlotLegend,
|
||||
),
|
||||
)
|
||||
|
||||
# Inject fake module into sys.modules
|
||||
@@ -87,7 +129,7 @@ def _kwargs_for(calls, key):
|
||||
raise KeyError(f"Key {key} not found in calls: {calls}")
|
||||
|
||||
|
||||
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
def test_log_rerun_data_envtransition_scalars_image_audio(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
|
||||
# Build EnvTransition dict
|
||||
@@ -95,6 +137,8 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
f"{OBS_STATE}.temperature": np.float32(25.0),
|
||||
# CHW image should be converted to HWC for rr.Image
|
||||
"observation.camera": np.zeros((3, 10, 20), dtype=np.uint8),
|
||||
# Multiple channels audio data should be split into separate channels and logged as rr.Scalars.columns
|
||||
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||
}
|
||||
act = {
|
||||
"action.throttle": 0.7,
|
||||
@@ -117,25 +161,27 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
# - action.throttle -> Scalars
|
||||
# - action.vector_0, action.vector_1 -> Scalars
|
||||
expected_keys = {
|
||||
f"{OBS_STATE}.temperature",
|
||||
"data/" + f"{OBS_STATE}.temperature",
|
||||
"observation.camera",
|
||||
"action.throttle",
|
||||
"action.vector_0",
|
||||
"action.vector_1",
|
||||
"data/action.throttle",
|
||||
"data/action.vector_0",
|
||||
"data/action.vector_1",
|
||||
"audio/observation.audio_channel_0",
|
||||
"audio/observation.audio_channel_1",
|
||||
}
|
||||
assert set(_keys(calls)) == expected_keys
|
||||
|
||||
# Check scalar types and values
|
||||
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
|
||||
temp_obj = _obj_for(calls, f"data/{OBS_STATE}.temperature")
|
||||
assert type(temp_obj).__name__ == "DummyScalar"
|
||||
assert temp_obj.value == pytest.approx(25.0)
|
||||
|
||||
throttle_obj = _obj_for(calls, "action.throttle")
|
||||
throttle_obj = _obj_for(calls, "data/action.throttle")
|
||||
assert type(throttle_obj).__name__ == "DummyScalar"
|
||||
assert throttle_obj.value == pytest.approx(0.7)
|
||||
|
||||
v0 = _obj_for(calls, "action.vector_0")
|
||||
v1 = _obj_for(calls, "action.vector_1")
|
||||
v0 = _obj_for(calls, "data/action.vector_0")
|
||||
v1 = _obj_for(calls, "data/action.vector_1")
|
||||
assert type(v0).__name__ == "DummyScalar"
|
||||
assert type(v1).__name__ == "DummyScalar"
|
||||
assert v0.value == pytest.approx(1.0)
|
||||
@@ -147,6 +193,14 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
|
||||
assert img_obj.arr.shape == (10, 20, 3) # transposed
|
||||
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
|
||||
|
||||
# Check audio handling: split channels + rr.Scalars.columns
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
|
||||
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
@@ -157,6 +211,8 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
"temp": 1.5,
|
||||
# Already HWC image => should stay as-is
|
||||
"img": np.zeros((5, 6, 3), dtype=np.uint8),
|
||||
# Multiple channels audio data should be split into separate channels
|
||||
"audio": np.zeros((100, 2), dtype=np.float32),
|
||||
"none": None, # should be skipped
|
||||
}
|
||||
act_plain = {
|
||||
@@ -170,22 +226,24 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
|
||||
# Expected keys with auto-prefixes
|
||||
expected = {
|
||||
"observation.temp",
|
||||
"data/observation.temp",
|
||||
"observation.img",
|
||||
"action.throttle",
|
||||
"action.vec_0",
|
||||
"action.vec_1",
|
||||
"action.vec_2",
|
||||
"data/action.throttle",
|
||||
"data/action.vec_0",
|
||||
"data/action.vec_1",
|
||||
"data/action.vec_2",
|
||||
"audio/observation.audio_channel_0",
|
||||
"audio/observation.audio_channel_1",
|
||||
}
|
||||
logged = set(_keys(calls))
|
||||
assert logged == expected
|
||||
|
||||
# Scalars
|
||||
t = _obj_for(calls, "observation.temp")
|
||||
t = _obj_for(calls, "data/observation.temp")
|
||||
assert type(t).__name__ == "DummyScalar"
|
||||
assert t.value == pytest.approx(1.5)
|
||||
|
||||
throttle = _obj_for(calls, "action.throttle")
|
||||
throttle = _obj_for(calls, "data/action.throttle")
|
||||
assert type(throttle).__name__ == "DummyScalar"
|
||||
assert throttle.value == pytest.approx(0.3)
|
||||
|
||||
@@ -197,25 +255,39 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
|
||||
|
||||
# Vectors
|
||||
for i, val in enumerate([9, 8, 7]):
|
||||
o = _obj_for(calls, f"action.vec_{i}")
|
||||
o = _obj_for(calls, f"data/action.vec_{i}")
|
||||
assert type(o).__name__ == "DummyScalar"
|
||||
assert o.value == pytest.approx(val)
|
||||
|
||||
# Audio
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
|
||||
def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
vu, calls = mock_rerun
|
||||
|
||||
vu.log_rerun_data(
|
||||
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
|
||||
observation={
|
||||
"observation.temp": 10.0,
|
||||
"observation.gray": np.zeros((8, 8, 1), dtype=np.uint8),
|
||||
"observation.audio": np.zeros((100, 2), dtype=np.float32),
|
||||
},
|
||||
action={"action.a": 1.0},
|
||||
)
|
||||
|
||||
keys = set(_keys(calls))
|
||||
assert "observation.temp" in keys
|
||||
assert "data/observation.temp" in keys
|
||||
assert "observation.gray" in keys
|
||||
assert "action.a" in keys
|
||||
assert "data/action.a" in keys
|
||||
assert "audio/observation.audio_channel_0" in keys
|
||||
assert "audio/observation.audio_channel_1" in keys
|
||||
|
||||
temp = _obj_for(calls, "observation.temp")
|
||||
temp = _obj_for(calls, "data/observation.temp")
|
||||
assert type(temp).__name__ == "DummyScalar"
|
||||
assert temp.value == pytest.approx(10.0)
|
||||
|
||||
@@ -224,6 +296,13 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
assert img.arr.shape == (8, 8, 1) # remains HWC
|
||||
assert _kwargs_for(calls, "observation.gray").get("static", False) is True
|
||||
|
||||
a = _obj_for(calls, "action.a")
|
||||
a = _obj_for(calls, "data/action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
|
||||
audio_obj_0 = _obj_for(calls, "audio/observation.audio_channel_0")
|
||||
audio_obj_1 = _obj_for(calls, "audio/observation.audio_channel_1")
|
||||
assert type(audio_obj_0).__name__ == "DummyScalarsColumn"
|
||||
assert type(audio_obj_1).__name__ == "DummyScalarsColumn"
|
||||
assert audio_obj_0.values.shape == (100,)
|
||||
assert audio_obj_1.values.shape == (100,)
|
||||
|
||||
Reference in New Issue
Block a user