mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
align molmoact2 feature validation with eo pattern
This commit is contained in:
@@ -334,7 +334,24 @@ class MolmoAct2Config(PreTrainedConfig):
|
||||
self.dataset_feature_names[key] = feature["names"]
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up MolmoAct2 input and output features."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
if not image_features:
|
||||
raise ValueError(
|
||||
"MolmoAct2 policy requires at least one visual input feature. "
|
||||
"No features of type FeatureType.VISUAL found in input_features."
|
||||
)
|
||||
|
||||
if OBS_STATE not in self.input_features:
|
||||
self.input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(0,))
|
||||
state_feature = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(0,),
|
||||
)
|
||||
self.input_features[OBS_STATE] = state_feature
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
self.output_features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(0,))
|
||||
action_feature = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(self.expected_max_action_dim,),
|
||||
)
|
||||
self.output_features[ACTION] = action_feature
|
||||
|
||||
Reference in New Issue
Block a user