Add more tests

This commit is contained in:
Eugene Mironov
2025-11-08 17:07:45 +07:00
parent ac33f20e51
commit 99eea2ae03
6 changed files with 2844 additions and 3 deletions
+4 -3
View File
@@ -217,9 +217,10 @@ class RTCProcessor:
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
squared_one_minus_tau = (1 - tau) ** 2
inv_r2 = (squared_one_minus_tau + tau**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau) / tau, posinf=max_guidance_weight)
tau_tensor = torch.as_tensor(tau)
squared_one_minus_tau = (1 - tau_tensor) ** 2
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)