diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 724f920f3..cec2f8f25 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def generate_model_card( self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None ) -> ModelCard: - base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model + base_model_mapping = { + "smolvla": "lerobot/smolvla_base", + "pi0": "lerobot/pi0_base", + "pi05": "lerobot/pi05_base", + "pi0_fast": "lerobot/pi0fast-base", + "xvla": "lerobot/xvla-base", + } card_data = ModelCardData( license=license or "apache-2.0", @@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): tags=list(set(tags or []).union({"robotics", "lerobot", model_type})), model_name=model_type, datasets=dataset_repo_id, - base_model=base_model, + base_model=base_model_mapping(model_type, None), ) template_card = (