feat(audio in policies): adding audio as a input feature in policies

This commit is contained in:
CarolinePascal
2025-04-28 19:20:05 +02:00
parent bf8ede852d
commit 926184110b
3 changed files with 9 additions and 0 deletions
+4
View File
@@ -151,6 +151,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
return {}
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def audio_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
@property
def action_feature(self) -> PolicyFeature | None:
if not self.output_features:
+1
View File
@@ -20,6 +20,7 @@ from enum import Enum
class FeatureType(str, Enum):
STATE = "STATE"
VISUAL = "VISUAL"
AUDIO = "AUDIO"
ENV = "ENV"
ACTION = "ACTION"
REWARD = "REWARD"
+4
View File
@@ -741,6 +741,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif ft["dtype"] == "audio":
type = FeatureType.AUDIO
if len(shape) != 2:
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
elif key == OBS_ENV_STATE:
type = FeatureType.ENV
elif key.startswith(OBS_STR):