mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
change validation
This commit is contained in:
@@ -230,6 +230,10 @@ def validate_visual_features_consistency(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||||
|
|
||||||
|
Validation passes if EITHER:
|
||||||
|
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
|
||||||
|
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||||
@@ -237,5 +241,11 @@ def validate_visual_features_consistency(
|
|||||||
"""
|
"""
|
||||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||||
if not provided_visuals.issubset(expected_visuals):
|
|
||||||
|
# Accept if either direction is a subset
|
||||||
|
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
|
||||||
|
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
|
||||||
|
|
||||||
|
if not (policy_subset_of_dataset or dataset_subset_of_policy):
|
||||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user