mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
feat(audio in policies): adding audio as an input feature in policies
This commit is contained in:
@@ -151,6 +151,10 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
return {}
|
return {}
|
||||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audio_features(self) -> dict[str, PolicyFeature]:
|
||||||
|
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.AUDIO}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_feature(self) -> PolicyFeature | None:
|
def action_feature(self) -> PolicyFeature | None:
|
||||||
if not self.output_features:
|
if not self.output_features:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from enum import Enum
|
|||||||
class FeatureType(str, Enum):
|
class FeatureType(str, Enum):
|
||||||
STATE = "STATE"
|
STATE = "STATE"
|
||||||
VISUAL = "VISUAL"
|
VISUAL = "VISUAL"
|
||||||
|
AUDIO = "AUDIO"
|
||||||
ENV = "ENV"
|
ENV = "ENV"
|
||||||
ACTION = "ACTION"
|
ACTION = "ACTION"
|
||||||
REWARD = "REWARD"
|
REWARD = "REWARD"
|
||||||
|
|||||||
@@ -741,6 +741,10 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
|
elif ft["dtype"] == "audio":
|
||||||
|
type = FeatureType.AUDIO
|
||||||
|
if len(shape) != 2:
|
||||||
|
raise ValueError(f"Number of dimensions of {key} != 2 (shape={shape})")
|
||||||
elif key == OBS_ENV_STATE:
|
elif key == OBS_ENV_STATE:
|
||||||
type = FeatureType.ENV
|
type = FeatureType.ENV
|
||||||
elif key.startswith(OBS_STR):
|
elif key.startswith(OBS_STR):
|
||||||
|
|||||||
Reference in New Issue
Block a user