mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
allow different state dim and action dim
This commit is contained in:
@@ -347,6 +347,16 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
def __init__(self, config: VLAJEPAConfig, **kwargs) -> None:
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
config.validate_features()
|
config.validate_features()
|
||||||
|
if dataset_meta := kwargs.get("dataset_meta"):
|
||||||
|
# cfg.input_features keeps the pretrained model's feature keys (needed for rename_map
|
||||||
|
# compatibility), so validate_features() may have read stale dims from a pretrained
|
||||||
|
# config. Override state_dim/action_dim from the actual dataset being used.
|
||||||
|
ds_features = dataset_meta.features
|
||||||
|
if OBS_STATE in ds_features:
|
||||||
|
config.state_dim = ds_features[OBS_STATE]["shape"][0]
|
||||||
|
if ACTION in ds_features:
|
||||||
|
config.action_dim = ds_features[ACTION]["shape"][0]
|
||||||
|
|
||||||
self.model = VLAJEPAModel(config)
|
self.model = VLAJEPAModel(config)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user