mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
Formatting
This commit is contained in:
@@ -461,7 +461,9 @@ def demo_cli(cfg: RTCDemoConfig):
|
|||||||
peft_pretrained_path = cfg.policy.pretrained_path
|
peft_pretrained_path = cfg.policy.pretrained_path
|
||||||
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
|
||||||
|
|
||||||
policy = policy_class.from_pretrained(pretrained_name_or_path=peft_config.base_model_name_or_path, config=config)
|
policy = policy_class.from_pretrained(
|
||||||
|
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
|
||||||
|
)
|
||||||
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
|
||||||
else:
|
else:
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||||
|
|||||||
@@ -493,7 +493,6 @@ def make_policy(
|
|||||||
# Make a fresh policy.
|
# Make a fresh policy.
|
||||||
policy = policy_cls(**kwargs)
|
policy = policy_cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
policy.to(cfg.device)
|
policy.to(cfg.device)
|
||||||
assert isinstance(policy, torch.nn.Module)
|
assert isinstance(policy, torch.nn.Module)
|
||||||
|
|
||||||
|
|||||||
@@ -289,7 +289,6 @@ def eval_policy(
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
|
|||||||
+1
-7
@@ -121,13 +121,7 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
|
|||||||
model_id = resolve_model_id_for_peft_training(policy_type)
|
model_id = resolve_model_id_for_peft_training(policy_type)
|
||||||
|
|
||||||
def dummy_update_policy(
|
def dummy_update_policy(
|
||||||
train_metrics,
|
train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs
|
||||||
policy,
|
|
||||||
batch,
|
|
||||||
optimizer,
|
|
||||||
grad_clip_norm: float,
|
|
||||||
accelerator,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
params_total = sum(p.numel() for p in policy.parameters())
|
params_total = sum(p.numel() for p in policy.parameters())
|
||||||
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
|
|||||||
Reference in New Issue
Block a user