mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-24 19:57:27 +00:00
Add more tests
This commit is contained in:
@@ -217,9 +217,10 @@ class RTCProcessor:
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
squared_one_minus_tau = (1 - tau) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau) / tau, posinf=max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user