mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
add more base models to generate model card
This commit is contained in:
@@ -248,7 +248,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
def generate_model_card(
|
def generate_model_card(
|
||||||
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None
|
||||||
) -> ModelCard:
|
) -> ModelCard:
|
||||||
base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model
|
base_model_mapping = {
|
||||||
|
"smolvla": "lerobot/smolvla_base",
|
||||||
|
"pi0": "lerobot/pi0_base",
|
||||||
|
"pi05": "lerobot/pi05_base",
|
||||||
|
"pi0_fast": "lerobot/pi0fast-base",
|
||||||
|
"xvla": "lerobot/xvla-base",
|
||||||
|
}
|
||||||
|
|
||||||
card_data = ModelCardData(
|
card_data = ModelCardData(
|
||||||
license=license or "apache-2.0",
|
license=license or "apache-2.0",
|
||||||
@@ -257,7 +263,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
tags=list(set(tags or []).union({"robotics", "lerobot", model_type})),
|
||||||
model_name=model_type,
|
model_name=model_type,
|
||||||
datasets=dataset_repo_id,
|
datasets=dataset_repo_id,
|
||||||
base_model=base_model,
|
base_model=base_model_mapping(model_type, None),
|
||||||
)
|
)
|
||||||
|
|
||||||
template_card = (
|
template_card = (
|
||||||
|
|||||||
Reference in New Issue
Block a user