mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()
This commit is contained in:
@@ -34,77 +34,6 @@ from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
def validate_rtc_behavior(
|
||||
rtc_actions: torch.Tensor,
|
||||
no_rtc_actions: torch.Tensor,
|
||||
prev_chunk: torch.Tensor,
|
||||
inference_delay: int,
|
||||
execution_horizon: int,
|
||||
rtol: float = 1e-1,
|
||||
):
|
||||
"""Validate RTC behavior follows expected rules.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_passed, failures) where failures is a list of error messages
|
||||
"""
|
||||
# Remove batch dimension if present and move to CPU
|
||||
rtc_actions_t = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||
no_rtc_actions_t = (
|
||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||
)
|
||||
prev_chunk_t = prev_chunk.squeeze(0).cpu() if len(prev_chunk.shape) == 3 else prev_chunk.cpu()
|
||||
|
||||
chunk_len = min(rtc_actions_t.shape[0], no_rtc_actions_t.shape[0], prev_chunk_t.shape[0])
|
||||
failures = []
|
||||
|
||||
# Rule 1: Delay region [0:inference_delay] - RTC should equal prev_chunk
|
||||
if inference_delay > 0:
|
||||
delay_end = min(inference_delay, chunk_len)
|
||||
rtc_delay = rtc_actions_t[:delay_end]
|
||||
prev_delay = prev_chunk_t[:delay_end]
|
||||
|
||||
if not torch.allclose(rtc_delay, prev_delay, rtol=rtol):
|
||||
max_diff = torch.max(torch.abs(rtc_delay - prev_delay)).item()
|
||||
failures.append(
|
||||
f"Delay region [0:{delay_end}]: RTC does NOT equal prev_chunk (max diff: {max_diff:.6f})"
|
||||
)
|
||||
|
||||
# Rule 2: Blend region [inference_delay:execution_horizon]
|
||||
blend_start = inference_delay
|
||||
blend_end = min(execution_horizon, chunk_len)
|
||||
|
||||
if blend_end > blend_start:
|
||||
rtc_blend = rtc_actions_t[blend_start:blend_end]
|
||||
prev_blend = prev_chunk_t[blend_start:blend_end]
|
||||
no_rtc_blend = no_rtc_actions_t[blend_start:blend_end]
|
||||
|
||||
min_bound = torch.minimum(prev_blend, no_rtc_blend)
|
||||
max_bound = torch.maximum(prev_blend, no_rtc_blend)
|
||||
within_bounds = torch.logical_and(rtc_blend >= min_bound, rtc_blend <= max_bound)
|
||||
|
||||
if not torch.all(within_bounds):
|
||||
violations = torch.sum(~within_bounds).item()
|
||||
total_elements = within_bounds.numel()
|
||||
failures.append(
|
||||
f"Blend region [{blend_start}:{blend_end}]: "
|
||||
f"RTC is NOT between prev_chunk and no_rtc ({violations}/{total_elements} violations)"
|
||||
)
|
||||
|
||||
# Rule 3: Post-horizon [execution_horizon:] - RTC should equal no_rtc
|
||||
if execution_horizon < chunk_len:
|
||||
rtc_after = rtc_actions_t[execution_horizon:chunk_len]
|
||||
no_rtc_after = no_rtc_actions_t[execution_horizon:chunk_len]
|
||||
|
||||
if not torch.allclose(rtc_after, no_rtc_after, rtol=rtol):
|
||||
max_diff = torch.max(torch.abs(rtc_after - no_rtc_after)).item()
|
||||
failures.append(
|
||||
f"Post-horizon [{execution_horizon}:{chunk_len}]: "
|
||||
f"RTC does NOT equal no_rtc (max diff: {max_diff:.6f})"
|
||||
)
|
||||
|
||||
return len(failures) == 0, failures
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_initialization():
|
||||
"""Test PI0 policy can initialize RTC processor."""
|
||||
@@ -373,23 +302,7 @@ def test_pi0_rtc_validation_rules():
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Validate RTC behavior rules
|
||||
all_passed, failures = validate_rtc_behavior(
|
||||
rtc_actions=actions_with_rtc,
|
||||
no_rtc_actions=actions_without_rtc,
|
||||
prev_chunk=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
if not all_passed:
|
||||
error_msg = "RTC validation failed:\n" + "\n".join(failures)
|
||||
pytest.fail(error_msg)
|
||||
|
||||
print("✓ PI0 RTC validation rules: All rules passed")
|
||||
print(" ✓ Delay region [0:4]: RTC = prev_chunk")
|
||||
print(" ✓ Blend region [4:10]: prev_chunk ≤ RTC ≤ no_rtc")
|
||||
print(" ✓ Post-horizon [10:]: RTC = no_rtc")
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
"""Test PI0 with different RTC attention schedules."""
|
||||
set_seed(42)
|
||||
|
||||
Reference in New Issue
Block a user