Formatting

This commit is contained in:
nemo
2025-12-20 19:18:56 +01:00
parent 43ca1dc216
commit 4bc75776f7
4 changed files with 4 additions and 10 deletions
+3 -1
View File
@@ -461,7 +461,9 @@ def demo_cli(cfg: RTCDemoConfig):
peft_pretrained_path = cfg.policy.pretrained_path peft_pretrained_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_pretrained_path) peft_config = PeftConfig.from_pretrained(peft_pretrained_path)
policy = policy_class.from_pretrained(pretrained_name_or_path=peft_config.base_model_name_or_path, config=config) policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=config
)
policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config)
else: else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
-1
View File
@@ -493,7 +493,6 @@ def make_policy(
# Make a fresh policy. # Make a fresh policy.
policy = policy_cls(**kwargs) policy = policy_cls(**kwargs)
policy.to(cfg.device) policy.to(cfg.device)
assert isinstance(policy, torch.nn.Module) assert isinstance(policy, torch.nn.Module)
-1
View File
@@ -289,7 +289,6 @@ def eval_policy(
except ImportError: except ImportError:
raise exc raise exc
start = time.time() start = time.time()
policy.eval() policy.eval()
+1 -7
View File
@@ -121,13 +121,7 @@ def test_peft_training_params_are_fewer(policy_type, tmp_path):
model_id = resolve_model_id_for_peft_training(policy_type) model_id = resolve_model_id_for_peft_training(policy_type)
def dummy_update_policy( def dummy_update_policy(
train_metrics, train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs
policy,
batch,
optimizer,
grad_clip_norm: float,
accelerator,
**kwargs
): ):
params_total = sum(p.numel() for p in policy.parameters()) params_total = sum(p.numel() for p in policy.parameters())
params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad) params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)