allow different state dim and action dim

This commit is contained in:
Maximellerbach
2026-05-18 14:32:32 +02:00
parent 1642a14ae5
commit 4e85e23350
@@ -347,6 +347,16 @@ class VLAJEPAPolicy(PreTrainedPolicy):
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
super().__init__(config)
config.validate_features()
if dataset_meta := kwargs.get("dataset_meta"):
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
# compatibility), so validate_features() may have read stale dims from a pretrained
# config. Override state_dim/action_dim from the actual dataset being used.
ds_features = dataset_meta.features
if OBS_STATE in ds_features:
config.state_dim = ds_features[OBS_STATE]["shape"][0]
if ACTION in ds_features:
config.action_dim = ds_features[ACTION]["shape"][0]
self.model = VLAJEPAModel(config)
self.reset()