add more base models to generate model card

This commit is contained in:
Nikodem Bartnik
2026-05-20 12:24:32 +02:00
parent c62784e14c
commit 99c0d93b34
+8 -2
View File
@@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
def generate_model_card( def generate_model_card(
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
) -> ModelCard: ) -> ModelCard:
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model base_model_mapping = {
"smolvla": "lerobot/smolvla_base",
"pi0": "lerobot/pi0_base",
"pi05": "lerobot/pi05_base",
"pi0_fast": "lerobot/pi0fast-base",
"xvla": "lerobot/xvla-base",
}
card_data = ModelCardData( card_data = ModelCardData(
license=license or "apache-2.0", license=license or "apache-2.0",
@@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})), tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
model_name=model_type, model_name=model_type,
datasets=dataset_repo_id, datasets=dataset_repo_id,
base_model=base_model, base_model=base_model_mapping(model_type, None),
) )
template_card = ( template_card = (