mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 16:27:04 +00:00
fixing final upload to hub
This commit is contained in:
@@ -303,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self,
|
||||
cfg: TrainPipelineConfig,
|
||||
peft_model=None,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
@@ -320,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
peft_model.save_pretrained(saved_path)
|
||||
self.config.save_pretrained(saved_path)
|
||||
else:
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
# Calls _save_pretrained and stores model tensors
|
||||
self.save_pretrained(saved_path, state_dict=state_dict)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
|
||||
@@ -644,6 +644,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
|
||||
if is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
@@ -653,7 +655,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
else:
|
||||
unwrapped_model.push_model_to_hub(cfg)
|
||||
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user