From 99c0d93b3412a702ce56d798448b97b17686ba6a Mon Sep 17 00:00:00 2001 From: Nikodem Bartnik Date: Wed, 20 May 2026 12:24:32 +0200 Subject: [PATCH] add more base models to generate model card --- src/lerobot/policies/pretrained.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 724f920f3..cec2f8f25 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def generate_model_card( self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None ) -> ModelCard: - base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model + base_model_mapping = { + "smolvla": "lerobot/smolvla_base", + "pi0": "lerobot/pi0_base", + "pi05": "lerobot/pi05_base", + "pi0_fast": "lerobot/pi0fast-base", + "xvla": "lerobot/xvla-base", + } card_data = ModelCardData( license=license or "apache-2.0", @@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): tags=list(set(tags or []).union({"robotics", "lerobot", model_type})), model_name=model_type, datasets=dataset_repo_id, - base_model=base_model, + base_model=base_model_mapping(model_type, None), ) template_card = (