From 926184110bde983228c0234d70ca824e12b73f7a Mon Sep 17 00:00:00 2001
From: CarolinePascal
Date: Mon, 28 Apr 2025 19:20:05 +0200
Subject: [PATCH] feat(audio in policies): adding audio as an input feature in policies

---
Review notes (text in this section is ignored by `git am`):
* `audio_features` now returns `{}` early when `input_features` is falsy,
  mirroring the early `return {}` of the property directly above it, so an
  unset (None) `input_features` no longer raises AttributeError on `.items()`.
  A one-line docstring was added to the new property.
* Line breaks and indentation were restored from a whitespace-collapsed copy
  of this patch; indentation of context lines is inferred from Python nesting.
  Verify against the tree, or apply with `git apply --ignore-whitespace`.
* The post-image blob id on the policies.py `index` line predates the amended
  hunk; regenerate the patch after applying if exact ids matter.

 src/lerobot/configs/policies.py | 7 +++++++
 src/lerobot/configs/types.py    | 1 +
 src/lerobot/datasets/utils.py   | 4 ++++
 3 files changed, 12 insertions(+)

diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index 7f326b70b..11494f0fd 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -151,6 +151,13 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
             return {}
         return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
 
+    @property
+    def audio_features(self) -> dict[str, PolicyFeature]:
+        """Return the AUDIO-typed input features keyed by name (empty when none are configured)."""
+        if not self.input_features:
+            return {}
+        return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
+
     @property
     def action_feature(self) -> PolicyFeature | None:
         if not self.output_features:
diff --git a/src/lerobot/configs/types.py b/src/lerobot/configs/types.py
index 18359ef05..228195399 100644
--- a/src/lerobot/configs/types.py
+++ b/src/lerobot/configs/types.py
@@ -20,6 +20,7 @@ from enum import Enum
 class FeatureType(str, Enum):
     STATE = "STATE"
     VISUAL = "VISUAL"
+    AUDIO = "AUDIO"
     ENV = "ENV"
     ACTION = "ACTION"
     REWARD = "REWARD"
diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py
index 34197fa6e..e5d0248e7 100644
--- a/src/lerobot/datasets/utils.py
+++ b/src/lerobot/datasets/utils.py
@@ -741,6 +741,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
             # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
             if names[2] in ["channel", "channels"]:  # (h, w, c) -> (c, h, w)
                 shape = (shape[2], shape[0], shape[1])
+        elif ft["dtype"] == "audio":
+            type = FeatureType.AUDIO
+            if len(shape) != 2:
+                raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
         elif key == OBS_ENV_STATE:
             type = FeatureType.ENV
         elif key.startswith(OBS_STR):