mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
add more base models to generate model card
This commit is contained in:
@@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
def generate_model_card(
|
||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||
) -> ModelCard:
|
||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
||||
base_model_mapping = {
|
||||
"smolvla": "lerobot/smolvla_base",
|
||||
"pi0": "lerobot/pi0_base",
|
||||
"pi05": "lerobot/pi05_base",
|
||||
"pi0_fast": "lerobot/pi0fast-base",
|
||||
"xvla": "lerobot/xvla-base",
|
||||
}
|
||||
|
||||
card_data = ModelCardData(
|
||||
license=license or "apache-2.0",
|
||||
@@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||
model_name=model_type,
|
||||
datasets=dataset_repo_id,
|
||||
base_model=base_model,
|
||||
base_model=base_model_mapping(model_type, None),
|
||||
)
|
||||
|
||||
template_card = (
|
||||
|
||||
Reference in New Issue
Block a user