mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
feat(audio in policies): adding audio as an input feature in policies
This commit is contained in:
@@ -151,6 +151,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
return {}
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
if not self.output_features:
|
||||
|
||||
@@ -20,6 +20,7 @@ from enum import Enum
|
||||
class FeatureType(str, Enum):
|
||||
STATE = "STATE"
|
||||
VISUAL = "VISUAL"
|
||||
AUDIO = "AUDIO"
|
||||
ENV = "ENV"
|
||||
ACTION = "ACTION"
|
||||
REWARD = "REWARD"
|
||||
|
||||
@@ -741,6 +741,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif ft["dtype"] == "audio":
|
||||
type = FeatureType.AUDIO
|
||||
if len(shape) != 2:
|
||||
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||
elif key == OBS_ENV_STATE:
|
||||
type = FeatureType.ENV
|
||||
elif key.startswith(OBS_STR):
|
||||
|
||||
Reference in New Issue
Block a user