mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization
This commit is contained in:
@@ -75,68 +75,6 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def tensor_stats_str(tensor: Tensor | None, name: str = "tensor") -> str:
|
|
||||||
"""Generate readable statistics string for a tensor."""
|
|
||||||
if tensor is None:
|
|
||||||
return f"{name}: None"
|
|
||||||
|
|
||||||
stats = (
|
|
||||||
f"{name}:\n"
|
|
||||||
f" shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}\n"
|
|
||||||
f" min={tensor.min().item():.6f}, max={tensor.max().item():.6f}\n"
|
|
||||||
f" mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}"
|
|
||||||
)
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
def compare_tensors(tensor1: Tensor, tensor2: Tensor, name1: str = "tensor1", name2: str = "tensor2") -> str:
|
|
||||||
"""Compare two tensors and return detailed difference statistics."""
|
|
||||||
if tensor1 is None or tensor2 is None:
|
|
||||||
return f"Cannot compare: {name1}={tensor1 is not None}, {name2}={tensor2 is not None}"
|
|
||||||
|
|
||||||
# Ensure same shape for comparison
|
|
||||||
if tensor1.shape != tensor2.shape:
|
|
||||||
return f"Shape mismatch: {name1}={tuple(tensor1.shape)} vs {name2}={tuple(tensor2.shape)}"
|
|
||||||
|
|
||||||
diff = tensor1 - tensor2
|
|
||||||
abs_diff = torch.abs(diff)
|
|
||||||
|
|
||||||
# Per-timestep statistics
|
|
||||||
if len(diff.shape) >= 2:
|
|
||||||
# Shape is (batch, time, action_dim) or (time, action_dim)
|
|
||||||
per_timestep_mean = abs_diff.mean(dim=-1) # Average across action dimensions
|
|
||||||
|
|
||||||
timestep_stats = "\n Per-timestep abs diff (averaged across action dims):\n"
|
|
||||||
if len(per_timestep_mean.shape) > 1:
|
|
||||||
# Has batch dimension
|
|
||||||
for batch_idx in range(per_timestep_mean.shape[0]):
|
|
||||||
timestep_stats += f" Batch {batch_idx}: ["
|
|
||||||
for t in range(min(10, per_timestep_mean.shape[1])): # Show first 10 timesteps
|
|
||||||
timestep_stats += f"{per_timestep_mean[batch_idx, t].item():.6f}, "
|
|
||||||
if per_timestep_mean.shape[1] > 10:
|
|
||||||
timestep_stats += "..."
|
|
||||||
timestep_stats += "]\n"
|
|
||||||
else:
|
|
||||||
timestep_stats += " ["
|
|
||||||
for t in range(min(10, len(per_timestep_mean))):
|
|
||||||
timestep_stats += f"{per_timestep_mean[t].item():.6f}, "
|
|
||||||
if len(per_timestep_mean) > 10:
|
|
||||||
timestep_stats += "..."
|
|
||||||
timestep_stats += "]\n"
|
|
||||||
else:
|
|
||||||
timestep_stats = ""
|
|
||||||
|
|
||||||
result = (
|
|
||||||
f"\nDifference: {name1} - {name2}:\n"
|
|
||||||
f" abs_diff: min={abs_diff.min().item():.6f}, max={abs_diff.max().item():.6f}\n"
|
|
||||||
f" abs_diff: mean={abs_diff.mean().item():.6f}, std={abs_diff.std().item():.6f}\n"
|
|
||||||
f" relative_diff: mean={abs_diff.mean().item() / (torch.abs(tensor2).mean().item() + 1e-8) * 100:.2f}%"
|
|
||||||
f"{timestep_stats}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class RobotWrapper:
|
class RobotWrapper:
|
||||||
def __init__(self, robot: Robot):
|
def __init__(self, robot: Robot):
|
||||||
self.robot = robot
|
self.robot = robot
|
||||||
|
|||||||
@@ -34,77 +34,6 @@ from lerobot.utils.random_utils import set_seed # noqa: E402
|
|||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def validate_rtc_behavior(
|
|
||||||
rtc_actions: torch.Tensor,
|
|
||||||
no_rtc_actions: torch.Tensor,
|
|
||||||
prev_chunk: torch.Tensor,
|
|
||||||
inference_delay: int,
|
|
||||||
execution_horizon: int,
|
|
||||||
rtol: float = 1e-2,
|
|
||||||
):
|
|
||||||
"""Validate RTC behavior follows expected rules.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (all_passed, failures) where failures is a list of error messages
|
|
||||||
"""
|
|
||||||
# Remove batch dimension if present and move to CPU
|
|
||||||
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
|
||||||
no_rtc_actions_t = (
|
|
||||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
|
||||||
)
|
|
||||||
prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu()
|
|
||||||
|
|
||||||
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0])
|
|
||||||
failures = []
|
|
||||||
|
|
||||||
# Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk
|
|
||||||
if inference_delay > 0:
|
|
||||||
delay_end = min(inference_delay, chunk_len)
|
|
||||||
rtc_delay = rtc_actions_t[:delay_end]
|
|
||||||
prev_delay = prev_chunk_t[:delay_end]
|
|
||||||
|
|
||||||
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
|
|
||||||
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
|
|
||||||
failures.append(
|
|
||||||
f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rule 2: Blend region [inference_delay:execution_horizon]
|
|
||||||
blend_start = inference_delay
|
|
||||||
blend_end = min(execution_horizon, chunk_len)
|
|
||||||
|
|
||||||
if blend_end > blend_start:
|
|
||||||
rtc_blend = rtc_actions_t[blend_start:blend_end]
|
|
||||||
prev_blend = prev_chunk_t[blend_start:blend_end]
|
|
||||||
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
|
|
||||||
|
|
||||||
min_bound = torch.minimum(prev_blend, no_rtc_blend)
|
|
||||||
max_bound = torch.maximum(prev_blend, no_rtc_blend)
|
|
||||||
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
|
||||||
|
|
||||||
if not torch.all(within_bounds):
|
|
||||||
violations = torch.sum(~within_bounds).item()
|
|
||||||
total_elements = within_bounds.numel()
|
|
||||||
failures.append(
|
|
||||||
f"Blend region [{blend_start}:{blend_end}]: "
|
|
||||||
f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc
|
|
||||||
if execution_horizon < chunk_len:
|
|
||||||
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
|
|
||||||
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
|
|
||||||
|
|
||||||
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
|
|
||||||
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
|
|
||||||
failures.append(
|
|
||||||
f"Post-horizon [{execution_horizon}:{chunk_len}]: "
|
|
||||||
f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return len(failures) == 0, failures
|
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
def test_pi05_rtc_initialization():
|
def test_pi05_rtc_initialization():
|
||||||
"""Test PI0.5 policy can initialize RTC processor."""
|
"""Test PI0.5 policy can initialize RTC processor."""
|
||||||
@@ -404,20 +333,4 @@ def test_pi05_rtc_validation_rules():
|
|||||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||||
policy.config.rtc_config.enabled = True
|
policy.config.rtc_config.enabled = True
|
||||||
|
|
||||||
# Validate RTC behavior rules
|
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||||
all_passed, failures = validate_rtc_behavior(
|
|
||||||
rtc_actions=actions_with_rtc,
|
|
||||||
no_rtc_actions=actions_without_rtc,
|
|
||||||
prev_chunk=prev_chunk,
|
|
||||||
inference_delay=inference_delay,
|
|
||||||
execution_horizon=execution_horizon,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not all_passed:
|
|
||||||
error_msg = "RTC validation failed:\n" + "\n".join(failures)
|
|
||||||
pytest.fail(error_msg)
|
|
||||||
|
|
||||||
print("✓ PI0.5 RTC validation rules: All rules passed")
|
|
||||||
print(" ✓ Delay region [0:4]: RTC = prev_chunk")
|
|
||||||
print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc")
|
|
||||||
print(" ✓ Post-horizon [10:]: RTC = no_rtc")
|
|
||||||
|
|||||||
@@ -28,77 +28,6 @@ from lerobot.utils.random_utils import set_seed # noqa: E402
|
|||||||
from tests.utils import require_cuda # noqa: E402
|
from tests.utils import require_cuda # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def validate_rtc_behavior(
|
|
||||||
rtc_actions: torch.Tensor,
|
|
||||||
no_rtc_actions: torch.Tensor,
|
|
||||||
prev_chunk: torch.Tensor,
|
|
||||||
inference_delay: int,
|
|
||||||
execution_horizon: int,
|
|
||||||
rtol: float = 1e-2,
|
|
||||||
):
|
|
||||||
"""Validate RTC behavior follows expected rules.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (all_passed, failures) where failures is a list of error messages
|
|
||||||
"""
|
|
||||||
# Remove batch dimension if present and move to CPU
|
|
||||||
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
|
||||||
no_rtc_actions_t = (
|
|
||||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
|
||||||
)
|
|
||||||
prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu()
|
|
||||||
|
|
||||||
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0])
|
|
||||||
failures = []
|
|
||||||
|
|
||||||
# Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk
|
|
||||||
if inference_delay > 0:
|
|
||||||
delay_end = min(inference_delay, chunk_len)
|
|
||||||
rtc_delay = rtc_actions_t[:delay_end]
|
|
||||||
prev_delay = prev_chunk_t[:delay_end]
|
|
||||||
|
|
||||||
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
|
|
||||||
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
|
|
||||||
failures.append(
|
|
||||||
f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rule 2: Blend region [inference_delay:execution_horizon]
|
|
||||||
blend_start = inference_delay
|
|
||||||
blend_end = min(execution_horizon, chunk_len)
|
|
||||||
|
|
||||||
if blend_end > blend_start:
|
|
||||||
rtc_blend = rtc_actions_t[blend_start:blend_end]
|
|
||||||
prev_blend = prev_chunk_t[blend_start:blend_end]
|
|
||||||
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
|
|
||||||
|
|
||||||
min_bound = torch.minimum(prev_blend, no_rtc_blend)
|
|
||||||
max_bound = torch.maximum(prev_blend, no_rtc_blend)
|
|
||||||
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
|
||||||
|
|
||||||
if not torch.all(within_bounds):
|
|
||||||
violations = torch.sum(~within_bounds).item()
|
|
||||||
total_elements = within_bounds.numel()
|
|
||||||
failures.append(
|
|
||||||
f"Blend region [{blend_start}:{blend_end}]: "
|
|
||||||
f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc
|
|
||||||
if execution_horizon < chunk_len:
|
|
||||||
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
|
|
||||||
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
|
|
||||||
|
|
||||||
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
|
|
||||||
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
|
|
||||||
failures.append(
|
|
||||||
f"Post-horizon [{execution_horizon}:{chunk_len}]: "
|
|
||||||
f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
|
|
||||||
)
|
|
||||||
|
|
||||||
return len(failures) == 0, failures
|
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
@require_cuda
|
||||||
def test_smolvla_rtc_initialization():
|
def test_smolvla_rtc_initialization():
|
||||||
"""Test SmolVLA policy can initialize RTC processor."""
|
"""Test SmolVLA policy can initialize RTC processor."""
|
||||||
@@ -377,99 +306,4 @@ def test_smolvla_rtc_validation_rules():
|
|||||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||||
policy.config.rtc_config.enabled = True
|
policy.config.rtc_config.enabled = True
|
||||||
|
|
||||||
# Validate RTC behavior rules
|
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||||
all_passed, failures = validate_rtc_behavior(
|
|
||||||
rtc_actions=actions_with_rtc,
|
|
||||||
no_rtc_actions=actions_without_rtc,
|
|
||||||
prev_chunk=prev_chunk,
|
|
||||||
inference_delay=inference_delay,
|
|
||||||
execution_horizon=execution_horizon,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not all_passed:
|
|
||||||
error_msg = "RTC validation failed:\n" + "\n".join(failures)
|
|
||||||
pytest.fail(error_msg)
|
|
||||||
|
|
||||||
print("✓ SmolVLA RTC validation rules: All rules passed")
|
|
||||||
print(" ✓ Delay region [0:4]: RTC = prev_chunk")
|
|
||||||
print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc")
|
|
||||||
print(" ✓ Post-horizon [10:]: RTC = no_rtc")
|
|
||||||
|
|
||||||
|
|
||||||
@require_cuda
|
|
||||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
|
||||||
def test_smolvla_rtc_different_schedules():
|
|
||||||
"""Test SmolVLA with different RTC attention schedules."""
|
|
||||||
set_seed(42)
|
|
||||||
|
|
||||||
schedules = [
|
|
||||||
RTCAttentionSchedule.ZEROS,
|
|
||||||
RTCAttentionSchedule.ONES,
|
|
||||||
RTCAttentionSchedule.LINEAR,
|
|
||||||
RTCAttentionSchedule.EXP,
|
|
||||||
]
|
|
||||||
|
|
||||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
|
||||||
|
|
||||||
config.input_features = {
|
|
||||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
|
||||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
|
||||||
}
|
|
||||||
config.output_features = {
|
|
||||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create dataset stats
|
|
||||||
dataset_stats = {
|
|
||||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
|
||||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
|
||||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
|
||||||
}
|
|
||||||
|
|
||||||
device = config.device
|
|
||||||
|
|
||||||
for schedule in schedules:
|
|
||||||
print(f"Testing schedule: {schedule}")
|
|
||||||
|
|
||||||
# Add RTC config with specific schedule
|
|
||||||
config.rtc_config = RTCConfig(
|
|
||||||
enabled=True,
|
|
||||||
execution_horizon=10,
|
|
||||||
max_guidance_weight=5.0,
|
|
||||||
prefix_attention_schedule=schedule,
|
|
||||||
debug=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Instantiate policy
|
|
||||||
policy = SmolVLAPolicy(config)
|
|
||||||
policy.eval()
|
|
||||||
preprocessor, _ = make_pre_post_processors(
|
|
||||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create dummy batch
|
|
||||||
batch = {
|
|
||||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
|
||||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
|
||||||
"task": ["Pick up the object"],
|
|
||||||
}
|
|
||||||
batch = preprocessor(batch)
|
|
||||||
|
|
||||||
# Create previous chunk
|
|
||||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
|
||||||
actions = policy.predict_action_chunk(
|
|
||||||
batch,
|
|
||||||
noise=noise,
|
|
||||||
prev_chunk_left_over=prev_chunk,
|
|
||||||
inference_delay=4,
|
|
||||||
execution_horizon=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify shape
|
|
||||||
assert actions.shape == (1, config.chunk_size, 7)
|
|
||||||
print(f" ✓ Schedule {schedule}: Test passed")
|
|
||||||
|
|
||||||
print("✓ SmolVLA RTC different schedules: All schedules tested")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user