diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index c4ca35b72..baecd3395 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -230,6 +230,10 @@ def validate_visual_features_consistency( ) -> None: """ Validates visual feature consistency between a policy config and provided dataset/environment features. + + Validation passes if EITHER: + - Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more) + - Dataset's provided visuals are a subset of policy (policy declares extras for flexibility) Args: cfg (PreTrainedConfig): The model or policy configuration containing input_features and type. @@ -237,5 +241,11 @@ def validate_visual_features_consistency( """ expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL} provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL} - if not provided_visuals.issubset(expected_visuals): + + # Accept if either direction is a subset + policy_subset_of_dataset = expected_visuals.issubset(provided_visuals) + dataset_subset_of_policy = provided_visuals.issubset(expected_visuals) + + if not (policy_subset_of_dataset or dataset_subset_of_policy): raise_feature_mismatch_error(provided_visuals, expected_visuals) +