diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index bc25b9f09..1f449f9c8 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -75,68 +75,6 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def tensor_stats_str(tensor: Tensor | None, name: str = "tensor") -> str: - """Generate readable statistics string for a tensor.""" - if tensor is None: - return f"{name}: None" - - stats = ( - f"{name}:\n" - f" shape={tuple(tensor.shape)}, dtype={tensor.dtype}, device={tensor.device}\n" - f" min={tensor.min().item():.6f}, max={tensor.max().item():.6f}\n" - f" mean={tensor.mean().item():.6f}, std={tensor.std().item():.6f}" - ) - return stats - - -def compare_tensors(tensor1: Tensor, tensor2: Tensor, name1: str = "tensor1", name2: str = "tensor2") -> str: - """Compare two tensors and return detailed difference statistics.""" - if tensor1 is None or tensor2 is None: - return f"Cannot compare: {name1}={tensor1 is not None}, {name2}={tensor2 is not None}" - - # Ensure same shape for comparison - if tensor1.shape != tensor2.shape: - return f"Shape mismatch: {name1}={tuple(tensor1.shape)} vs {name2}={tuple(tensor2.shape)}" - - diff = tensor1 - tensor2 - abs_diff = torch.abs(diff) - - # Per-timestep statistics - if len(diff.shape) >= 2: - # Shape is (batch, time, action_dim) or (time, action_dim) - per_timestep_mean = abs_diff.mean(dim=-1) # Average across action dimensions - - timestep_stats = "\n Per-timestep abs diff (averaged across action dims):\n" - if len(per_timestep_mean.shape) > 1: - # Has batch dimension - for batch_idx in range(per_timestep_mean.shape[0]): - timestep_stats += f" Batch {batch_idx}: [" - for t in range(min(10, per_timestep_mean.shape[1])): # Show first 10 timesteps - timestep_stats += f"{per_timestep_mean[batch_idx, t].item():.6f}, " - if per_timestep_mean.shape[1] > 10: - timestep_stats += "..." - timestep_stats += "]\n" - else: - timestep_stats += " [" - for t in range(min(10, len(per_timestep_mean))): - timestep_stats += f"{per_timestep_mean[t].item():.6f}, " - if len(per_timestep_mean) > 10: - timestep_stats += "..." - timestep_stats += "]\n" - else: - timestep_stats = "" - - result = ( - f"\nDifference: {name1} - {name2}:\n" - f" abs_diff: min={abs_diff.min().item():.6f}, max={abs_diff.max().item():.6f}\n" - f" abs_diff: mean={abs_diff.mean().item():.6f}, std={abs_diff.std().item():.6f}\n" - f" relative_diff: mean={abs_diff.mean().item() / (torch.abs(tensor2).mean().item() + 1e-8) * 100:.2f}%" - f"{timestep_stats}" - ) - - return result - - class RobotWrapper: def __init__(self, robot: Robot): self.robot = robot diff --git a/tests/policies/pi0_pi05/test_pi05_rtc.py b/tests/policies/pi0_pi05/test_pi05_rtc.py index b58ff49dc..3a753031f 100644 --- a/tests/policies/pi0_pi05/test_pi05_rtc.py +++ b/tests/policies/pi0_pi05/test_pi05_rtc.py @@ -34,77 +34,6 @@ from lerobot.utils.random_utils import set_seed # noqa: E402 from tests.utils import require_cuda # noqa: E402 -def validate_rtc_behavior( - rtc_actions: torch.Tensor, - no_rtc_actions: torch.Tensor, - prev_chunk: torch.Tensor, - inference_delay: int, - execution_horizon: int, - rtol: float = 1e-2, -): - """Validate RTC behavior follows expected rules. - - Returns: - Tuple of (all_passed, failures) where failures is a list of error messages - """ - # Remove batch dimension if present and move to CPU - rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu() - no_rtc_actions_t = ( - no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu() - ) - prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu() - - chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0]) - failures = [] - - # Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk - if inference_delay > 0: - delay_end = min(inference_delay, chunk_len) - rtc_delay = rtc_actions_t[:delay_end] - prev_delay = prev_chunk_t[:delay_end] - - if not torch.allclose(rtc_delay, prev_delay, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item() - failures.append( - f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})" - ) - - # Rule 2: Blend region [inference_delay:execution_horizon] - blend_start = inference_delay - blend_end = min(execution_horizon, chunk_len) - - if blend_end > blend_start: - rtc_blend = rtc_actions_t[blend_start:blend_end] - prev_blend = prev_chunk_t[blend_start:blend_end] - no_rtc_blend = no_rtc_actions_t[blend_start:blend_end] - - min_bound = torch.minimum(prev_blend, no_rtc_blend) - max_bound = torch.maximum(prev_blend, no_rtc_blend) - within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) - - if not torch.all(within_bounds): - violations = torch.sum(~within_bounds).item() - total_elements = within_bounds.numel() - failures.append( - f"Blend region [{blend_start}:{blend_end}]: " - f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)" - ) - - # Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc - if execution_horizon < chunk_len: - rtc_after = rtc_actions_t[execution_horizon:chunk_len] - no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len] - - if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item() - failures.append( - f"Post-horizon [{execution_horizon}:{chunk_len}]: " - f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})" - ) - - return len(failures) == 0, failures - - @require_cuda def test_pi05_rtc_initialization(): """Test PI0.5 policy can initialize RTC processor.""" @@ -404,20 +333,4 @@ def test_pi05_rtc_validation_rules(): actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone()) policy.config.rtc_config.enabled = True - # Validate RTC behavior rules - all_passed, failures = validate_rtc_behavior( - rtc_actions=actions_with_rtc, - no_rtc_actions=actions_without_rtc, - prev_chunk=prev_chunk, - inference_delay=inference_delay, - execution_horizon=execution_horizon, - ) - - if not all_passed: - error_msg = "RTC validation failed:\n" + "\n".join(failures) - pytest.fail(error_msg) - - print("✓ PI0.5 RTC validation rules: All rules passed") - print(" ✓ Delay region [0:4]: RTC = prev_chunk") - print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc") - print(" ✓ Post-horizon [10:]: RTC = no_rtc") + assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3) diff --git a/tests/policies/smolvla/test_smolvla_rtc.py b/tests/policies/smolvla/test_smolvla_rtc.py index f14bbf2f6..6888dee8d 100644 --- a/tests/policies/smolvla/test_smolvla_rtc.py +++ b/tests/policies/smolvla/test_smolvla_rtc.py @@ -28,77 +28,6 @@ from lerobot.utils.random_utils import set_seed # noqa: E402 from tests.utils import require_cuda # noqa: E402 -def validate_rtc_behavior( - rtc_actions: torch.Tensor, - no_rtc_actions: torch.Tensor, - prev_chunk: torch.Tensor, - inference_delay: int, - execution_horizon: int, - rtol: float = 1e-2, -): - """Validate RTC behavior follows expected rules. - - Returns: - Tuple of (all_passed, failures) where failures is a list of error messages - """ - # Remove batch dimension if present and move to CPU - rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu() - no_rtc_actions_t = ( - no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu() - ) - prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu() - - chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0]) - failures = [] - - # Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk - if inference_delay > 0: - delay_end = min(inference_delay, chunk_len) - rtc_delay = rtc_actions_t[:delay_end] - prev_delay = prev_chunk_t[:delay_end] - - if not torch.allclose(rtc_delay, prev_delay, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item() - failures.append( - f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})" - ) - - # Rule 2: Blend region [inference_delay:execution_horizon] - blend_start = inference_delay - blend_end = min(execution_horizon, chunk_len) - - if blend_end > blend_start: - rtc_blend = rtc_actions_t[blend_start:blend_end] - prev_blend = prev_chunk_t[blend_start:blend_end] - no_rtc_blend = no_rtc_actions_t[blend_start:blend_end] - - min_bound = torch.minimum(prev_blend, no_rtc_blend) - max_bound = torch.maximum(prev_blend, no_rtc_blend) - within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound) - - if not torch.all(within_bounds): - violations = torch.sum(~within_bounds).item() - total_elements = within_bounds.numel() - failures.append( - f"Blend region [{blend_start}:{blend_end}]: " - f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)" - ) - - # Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc - if execution_horizon < chunk_len: - rtc_after = rtc_actions_t[execution_horizon:chunk_len] - no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len] - - if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol): - max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item() - failures.append( - f"Post-horizon [{execution_horizon}:{chunk_len}]: " - f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})" - ) - - return len(failures) == 0, failures - - @require_cuda def test_smolvla_rtc_initialization(): """Test SmolVLA policy can initialize RTC processor.""" @@ -377,99 +306,4 @@ def test_smolvla_rtc_validation_rules(): actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone()) policy.config.rtc_config.enabled = True - # Validate RTC behavior rules - all_passed, failures = validate_rtc_behavior( - rtc_actions=actions_with_rtc, - no_rtc_actions=actions_without_rtc, - prev_chunk=prev_chunk, - inference_delay=inference_delay, - execution_horizon=execution_horizon, - ) - - if not all_passed: - error_msg = "RTC validation failed:\n" + "\n".join(failures) - pytest.fail(error_msg) - - print("✓ SmolVLA RTC validation rules: All rules passed") - print(" ✓ Delay region [0:4]: RTC = prev_chunk") - print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc") - print(" ✓ Post-horizon [10:]: RTC = no_rtc") - - -@require_cuda -@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights") -def test_smolvla_rtc_different_schedules(): - """Test SmolVLA with different RTC attention schedules.""" - set_seed(42) - - schedules = [ - RTCAttentionSchedule.ZEROS, - RTCAttentionSchedule.ONES, - RTCAttentionSchedule.LINEAR, - RTCAttentionSchedule.EXP, - ] - - config = SmolVLAConfig(max_action_dim=7, chunk_size=50) - - config.input_features = { - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)), - "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), - } - config.output_features = { - "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), - } - - # Create dataset stats - dataset_stats = { - "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)}, - "action": {"mean": torch.zeros(7), "std": torch.ones(7)}, - "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)}, - } - - device = config.device - - for schedule in schedules: - print(f"Testing schedule: {schedule}") - - # Add RTC config with specific schedule - config.rtc_config = RTCConfig( - enabled=True, - execution_horizon=10, - max_guidance_weight=5.0, - prefix_attention_schedule=schedule, - debug=False, - ) - - # Instantiate policy - policy = SmolVLAPolicy(config) - policy.eval() - preprocessor, _ = make_pre_post_processors( - policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats - ) - - # Create dummy batch - batch = { - "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device), - "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device), - "task": ["Pick up the object"], - } - batch = preprocessor(batch) - - # Create previous chunk - prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device) - - with torch.no_grad(): - noise = policy.model.sample_noise((1, config.chunk_size, 7), device) - actions = policy.predict_action_chunk( - batch, - noise=noise, - prev_chunk_left_over=prev_chunk, - inference_delay=4, - execution_horizon=10, - ) - - # Verify shape - assert actions.shape == (1, config.chunk_size, 7) - print(f" ✓ Schedule {schedule}: Test passed") - - print("✓ SmolVLA RTC different schedules: All schedules tested") + assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)