Formatting

This commit is contained in:
nemo
2025-12-20 19:18:56 +01:00
parent 43ca1dc216
commit 4bc75776f7
4 changed files with 4 additions and 10 deletions
+3 -1
View File
@@ -461,7 +461,9 @@ def demo_cli(cfg: RTCDemoConfig):
peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(pretrained_name_or_path=peft_config.base_model_name_or_path, config=config)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
-1
View File
@@ -493,7 +493,6 @@ def make_policy(
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module)
-1
View File
@@ -289,7 +289,6 @@ def eval_policy(
except ImportError:
raise exc
start = time.time()
policy.eval()
+1 -7
View File
@@ -121,13 +121,7 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
model_id = resolve_model_id_for_peft_training(policy_type)
def dummy_update_policy(
train_metrics,
policy,
batch,
optimizer,
grad_clip_norm: float,
accelerator,
**kwargs
train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs
):
params_total = sum(p.numel() for p in policy.parameters())
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)