fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

This commit is contained in:
Eugene Mironov
2025-11-12 00:55:13 +07:00
parent 5ff66e498f
commit 9a38c5f4d2
3 changed files with 2 additions and 317 deletions
-62
View File
@@ -75,68 +75,6 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def tensor_stats_str(tensor: Tensor | None, name: str = "tensor") -> str:
"""Generate readable statistics string for a tensor."""
if tensor is None:
return f"{name}: None"
stats = (
f"{name}:\n"
f" shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}\n"
f" min={tensor.min().item():.6f}, max={tensor.max().item():.6f}\n"
f" mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}"
)
return stats
def compare_tensors(tensor1: Tensor, tensor2: Tensor, name1: str = "tensor1", name2: str = "tensor2") -> str:
"""Compare two tensors and return detailed difference statistics."""
if tensor1 is None or tensor2 is None:
return f"Cannot compare: {name1}={tensor1 is not None}, {name2}={tensor2 is not None}"
# Ensure same shape for comparison
if tensor1.shape != tensor2.shape:
return f"Shape mismatch: {name1}={tuple(tensor1.shape)} vs {name2}={tuple(tensor2.shape)}"
diff = tensor1 - tensor2
abs_diff = torch.abs(diff)
# Per-timestep statistics
if len(diff.shape) >= 2:
# Shape is (batch, time, action_dim) or (time, action_dim)
per_timestep_mean = abs_diff.mean(dim=-1) # Average across action dimensions
timestep_stats = "\n Per-timestep abs diff (averaged across action dims):\n"
if len(per_timestep_mean.shape) > 1:
# Has batch dimension
for batch_idx in range(per_timestep_mean.shape[0]):
timestep_stats += f" Batch {batch_idx}: ["
for t in range(min(10, per_timestep_mean.shape[1])): # Show first 10 timesteps
timestep_stats += f"{per_timestep_mean[batch_idx, t].item():.6f}, "
if per_timestep_mean.shape[1] > 10:
timestep_stats += "..."
timestep_stats += "]\n"
else:
timestep_stats += " ["
for t in range(min(10, len(per_timestep_mean))):
timestep_stats += f"{per_timestep_mean[t].item():.6f}, "
if len(per_timestep_mean) > 10:
timestep_stats += "..."
timestep_stats += "]\n"
else:
timestep_stats = ""
result = (
f"\nDifference: {name1} - {name2}:\n"
f" abs_diff: min={abs_diff.min().item():.6f}, max={abs_diff.max().item():.6f}\n"
f" abs_diff: mean={abs_diff.mean().item():.6f}, std={abs_diff.std().item():.6f}\n"
f" relative_diff: mean={abs_diff.mean().item() / (torch.abs(tensor2).mean().item() + 1e-8) * 100:.2f}%"
f"{timestep_stats}"
)
return result
class RobotWrapper:
def __init__(self, robot: Robot):
self.robot = robot