From 2e0deff3ab1ef3a4358afb830334b60fce18a575 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Wed, 17 Jun 2026 09:42:05 +0000 Subject: [PATCH] fixing final upload to hub --- src/lerobot/policies/pretrained.py | 4 +++- src/lerobot/scripts/lerobot_train.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index eedf9d99e..a7aabb3f3 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -303,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): self, cfg: TrainPipelineConfig, peft_model=None, + state_dict: dict[str, Tensor] | None = None, ): api = HfApi() repo_id = api.create_repo( @@ -320,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): peft_model.save_pretrained(saved_path) self.config.save_pretrained(saved_path) else: - self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors + # Calls _save_pretrained and stores model tensors + self.save_pretrained(saved_path, state_dict=state_dict) card = self.generate_model_card( cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index add506887..a235d6248 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -644,6 +644,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if eval_env: close_envs(eval_env) + is_fsdp = accelerator.distributed_type == DistributedType.FSDP + model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None if is_main_process: logging.info("End of training") @@ -653,7 +655,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if not cfg.is_reward_model_training and cfg.policy.use_peft: unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model) else: - unwrapped_model.push_model_to_hub(cfg) + unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict) preprocessor.push_to_hub(active_cfg.repo_id) postprocessor.push_to_hub(active_cfg.repo_id)