lerobot/tests/policies/rtc/test_debug_tracker.py
Eugene Mironov 8a915c6b6f [RTC] Real Time Chunking for Pi0, Smolvla, Pi0.5 (#1698)
* Add Real-Time Chunking (RTC) support for flow matching models

Implement Real-Time Chunking (RTC) for action chunking policies using flow
matching denoising. RTC enables smooth action transitions between consecutive
chunks by using prefix guidance during denoising.

Key features:
- RTCProcessor class with denoise_step method for RTC guidance
- Tracker system for debug tracking using time-based dictionary storage
- RTCDebugVisualizer with comprehensive visualization utilities
- Integration with SmolVLA policy for flow matching models
- Support for multiple prefix attention schedules (ZEROS, ONES, LINEAR, EXP)
- Configurable execution horizon and max guidance weight
- Example scripts for dataset evaluation and real-time control
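The prefix attention schedules listed above control how strongly each position of the new chunk is pulled towards the previous chunk's remaining actions during guided denoising. Below is a minimal, hypothetical sketch in plain Python. The function name `prefix_weights`, the lowercase schedule strings, and the exact ramp formulas (a linear ramp and a convex exponential variant of it) are assumptions, not lerobot's actual code. Here `start` plays the role of the inference delay and `end` the execution horizon.

```python
import math
from typing import List


def prefix_weights(schedule: str, start: int, end: int, total: int) -> List[float]:
    """Hypothetical per-position guidance weights for a chunk of `total` actions.

    Positions < start are fully tied to the previous chunk (weight 1), weights
    then decay towards 0, and positions >= end ignore the prefix (weight 0).
    """
    start = min(start, end)  # assumption: `end` takes precedence over `start`
    weights = []
    for i in range(total):
        if i >= end:
            w = 0.0
        elif schedule == "ones":
            w = 1.0
        elif schedule == "zeros":
            w = 1.0 if i < start else 0.0
        elif schedule in ("linear", "exp"):
            # Linear ramp: 1 before `start`, falling towards 0 at `end`, clipped to [0, 1].
            w = min(max((start - 1 - i) / (end - start + 1) + 1, 0.0), 1.0)
            if schedule == "exp":
                # Same endpoints (0 and 1) but a convex curve below the linear ramp.
                w = w * math.expm1(w) / (math.e - 1)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")
        weights.append(w)
    return weights
```

For example, `prefix_weights("linear", start=2, end=6, total=10)` gives full weight to the first two actions, ramps down over positions 2 to 5, and ignores the prefix from position 6 onwards.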

Technical details:
- Uses autograd-based gradient computation for RTC corrections
- Time-based tracking eliminates duplicate step issues
- Proxy methods in RTCProcessor for cleaner API
- Full integration with LeRobot's policy and dataset systems
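The time-based tracking described above can be illustrated with a small dictionary keyed by the rounded denoising time, which is the behavior the tests further down exercise: repeated calls at the same time merge into one step, times are rounded to 6 decimals, `maxlen` evicts the oldest entries, and the step counter keeps increasing. `MiniTracker` and `MiniStep` are hypothetical stand-ins written for this sketch; unlike the real `Tracker`, they store plain values and do not clone tensors.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class MiniStep:
    step_idx: int
    time: float
    x_t: Any = None
    v_t: Any = None
    metadata: Dict[str, Any] = field(default_factory=dict)


class MiniTracker:
    """Hypothetical time-keyed debug store (a sketch, not lerobot's Tracker)."""

    def __init__(self, enabled: bool = True, maxlen: Optional[int] = None):
        self.enabled = enabled
        self._maxlen = maxlen
        self._steps: Dict[float, MiniStep] = {}  # dicts preserve insertion order
        self._step_counter = 0

    def track(self, time: float, x_t: Any = None, v_t: Any = None, **metadata: Any) -> None:
        if not self.enabled:
            return
        key = round(float(time), 6)  # absorb float noise so repeated calls hit one step
        step = self._steps.get(key)
        if step is None:
            step = MiniStep(step_idx=self._step_counter, time=key)
            self._step_counter += 1  # keeps counting even after evictions
            self._steps[key] = step
            if self._maxlen is not None and len(self._steps) > self._maxlen:
                del self._steps[next(iter(self._steps))]  # evict the oldest insertion
        # Merge: only overwrite fields that were actually provided.
        if x_t is not None:
            step.x_t = x_t
        if v_t is not None:
            step.v_t = v_t
        step.metadata.update(metadata)

    def get_all_steps(self) -> List[MiniStep]:
        return list(self._steps.values()) if self.enabled else []

    def __len__(self) -> int:
        return len(self._steps) if self.enabled else 0
```

Keying by time rather than by call count is what lets several call sites (for example one recording `x_t`, another recording `v_t`) contribute to the same denoising step without creating duplicates.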

Files added/modified:
- src/lerobot/configs/types.py: Add RTCAttentionSchedule enum
- src/lerobot/policies/rtc/: Core RTC implementation
  - configuration_rtc.py: RTC configuration
  - modeling_rtc.py: RTCProcessor with denoise_step
  - debug_handler.py: Tracker for debug information
  - debug_visualizer.py: Visualization utilities
- src/lerobot/policies/smolvla/modeling_smolvla.py: RTC integration
- examples/rtc/: Example scripts and evaluation tools

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Fix rtc_config attribute access in SmolVLA

Use getattr() to safely check for rtc_config attribute existence
instead of direct attribute access. This fixes AttributeError when
loading policies without rtc_config in their config.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix rtc_config attribute access in SmolVLA

* Add RTCConfig field to SmolVLAConfig

Add rtc_config as an optional field in SmolVLAConfig to properly
support Real-Time Chunking configuration. This replaces the previous
getattr() workarounds with direct attribute access, making the code
cleaner and more maintainable.

Changes:
- Import RTCConfig in configuration_smolvla.py
- Add rtc_config: RTCConfig | None = None field
- Revert getattr() calls to direct attribute access in modeling_smolvla.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Refactor RTC enabled checks to use _rtc_enabled helper

Add _rtc_enabled() helper method in VLAFlowMatching class to simplify
and clean up RTC enabled checks throughout the code. This reduces
code duplication and improves readability.

Changes:
- Add _rtc_enabled() method in VLAFlowMatching
- Replace verbose rtc_config checks with _rtc_enabled() calls
- Maintain exact same functionality with cleaner code
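A sketch of the optional config field and the `_rtc_enabled()` helper described in the commits above. `RTCConfig`, `PolicyConfig`, and `FlowMatchingModel` are hypothetical minimal stand-ins (the real classes carry many more fields), and the `enabled` flag on the RTC config is an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class RTCConfig:  # hypothetical stand-in; the real RTCConfig has more fields
    enabled: bool = False


@dataclass
class PolicyConfig:  # hypothetical stand-in for SmolVLAConfig
    rtc_config: RTCConfig | None = None  # optional: older configs simply leave it unset


class FlowMatchingModel:  # hypothetical stand-in for VLAFlowMatching
    def __init__(self, config: PolicyConfig):
        self.config = config

    def _rtc_enabled(self) -> bool:
        # A single place that answers "is RTC on?" instead of repeating None checks.
        return self.config.rtc_config is not None and self.config.rtc_config.enabled
```

Declaring the field with a `None` default is what allows direct attribute access: a config loaded without RTC settings still has `rtc_config`, it is just `None`.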

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Rename track_debug method to track

Simplify the method name from track_debug to just track for better
readability and consistency. The method already has clear documentation
about its debug tracking purpose.

Changes:
- Rename RTCProcessor.track_debug() to track()
- Update all call sites in modeling_smolvla.py and modeling_rtc.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* Use output_dir for saving all evaluation images

Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Use output_dir for saving all evaluation images

* Fix logging buffering and enable tracking when RTC config provided

- Add force=True to logging.basicConfig to override existing configuration
- Enable line buffering for stdout/stderr for real-time log output
- Modify init_rtc_processor to create processor when rtc_config exists
  even if RTC is disabled, allowing tracking of denoising data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor SmolVLA plotting to use tracker data instead of local variables

Remove local tracking variables (correction, x1_t, error) from the
denoising loop and instead retrieve plotting data from the RTC tracker
after each denoise step. This makes the code cleaner and uses the
tracker as the single source of truth for debug/visualization data.

Changes:
- Remove initialization of correction, x1_t, error before denoising loop
- After each Euler step, retrieve most recent debug step from tracker
- Extract correction, x1_t, err from debug step for plotting
- Update tracking condition to use is_debug_enabled() method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Move plotting logic from modeling_smolvla to eval_dataset script

Refactor to improve separation of concerns:

modeling_smolvla.py changes:
- Remove all plotting logic from sample_actions method
- Remove viz_xt_axs, viz_vt_axs, viz_x1t_axs parameters
- Remove matplotlib and RTCDebugVisualizer imports
- Remove viz_fig, viz_axs, denoise_step_counter instance variables
- Simplify denoising loop to only track data in rtc_processor

eval_dataset.py changes:
- Add _plot_denoising_steps_from_tracker helper method
- Retrieve debug steps from tracker after inference
- Plot x_t, v_t, x1_t, correction, and error from tracker data
- Enable debug tracking (cfg.rtc.debug = True) for visualization
- Remove viz axes parameters from predict_action_chunk calls

modeling_rtc.py changes:
- Remove v_t from track() call (handled by user change)

Benefits:
- Cleaner modeling code focused on inference
- Evaluation script owns all visualization logic
- Better separation of concerns
- Tracker is single source of truth for debug data

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* Refactor plotting loging

* fixup! Refactor plotting loging

* Improve visualization: separate correction plot and fix axis scaling

Changes:
- Create separate figure for correction data instead of overlaying on v_t
- Add _rescale_axes helper method to properly scale all axes
- Add 10% margin to y-axis for better visualization
- Fix v_t chart vertical compression issue
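The y-axis rescaling with a 10% margin amounts to padding the data range on both sides. `padded_limits` is a hypothetical helper, not the script's `_rescale_axes`, and it assumes the 10% is measured relative to the data range and applied to each side.

```python
from typing import Iterable, Tuple


def padded_limits(values: Iterable[float], margin: float = 0.1) -> Tuple[float, float]:
    """Return (low, high) limits covering `values` plus `margin` of the range on each side."""
    vals = list(values)
    if not vals:
        raise ValueError("need at least one value")
    low, high = min(vals), max(vals)
    span = high - low
    if span == 0:
        # Flat data: fall back to an absolute pad so the axis is not degenerate.
        span = abs(low) if low != 0 else 1.0
    pad = margin * span
    return low - pad, high + pad
```

With matplotlib, the result could be passed to `ax.set_ylim(*padded_limits(values))`. Computing limits from the data actually plotted on each axis, rather than sharing one scale, is what prevents a small-amplitude series such as `v_t` from being squashed flat.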

Benefits:
- Clearer v_t plot without correction overlay
- Better axis scaling with proper margins
- Separate correction figure for focused analysis
- Improved readability of all denoising visualizations

Output files:
- denoising_xt_comparison.png (x_t trajectories)
- denoising_vt_comparison.png (v_t velocity - now cleaner)
- denoising_correction_comparison.png (NEW - separate corrections)
- denoising_x1t_comparison.png (x1_t state with error)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>

* fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* fixup! fixup! fixup! Improve visualization: separate correction plot and fix axis scaling

* Fix tracking

* Right kwargs for the policy

* Add tests for tracker

* Fix tests

* Drop not required methods

* Add torch compilation for eval_dataset

* delete policies

* Add matplotlib to dev

* fixup! Add matplotlib to dev

* Experiment with late detach

* Debug

* Fix compilation

* Add RTC to PI0

* Pi0

* Pi0 eval dataset

* fixup! Pi0 eval dataset

* Turn off compilation for pi0/pi05

* fixup! Turn off compilation for pi0/pi05

* fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* fixup! fixup! fixup! fixup! fixup! Turn off compilation for pi0/pi05

* Add workable flow

* Small fixes

* Add more tests

* Add validation at the end

* Update README

* Silent validation

* Fix tests

* Add tests for modeling_rtc

* Add tests for flow matching models with RTC

* fixup! Add tests for flow matching models with RTC

* fixup! fixup! Add tests for flow matching models with RTC

* Add one more test

* fixup! Add one more test

* Fix test to use _rtc_enabled() instead of is_rtc_enabled()

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* fixup! fixup! Fix test to use _rtc_enabled() instead of is_rtc_enabled()

* Add RTC initialization tests without config for PI0.5 and SmolVLA

Add test_pi05_rtc_initialization_without_rtc_config and
test_smolvla_rtc_initialization_without_rtc_config to verify that
policies can initialize without RTC config and that _rtc_enabled()
returns False in this case.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix SmolVLA init_rtc_processor to use getattr instead of direct model access

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
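Quantile normalization maps each dimension's [q01, q99] range (the 1st and 99th percentiles) onto [-1, 1], which makes it less sensitive to outliers than min/max scaling. The sketch below assumes the convention `2 * (x - q01) / (q99 - q01) - 1` with a small epsilon in the denominator; the function names and the epsilon value are illustrative and may not match lerobot's normalizer.

```python
def quantile_normalize(x: float, q01: float, q99: float, eps: float = 1e-8) -> float:
    """Map [q01, q99] onto [-1, 1]; values outside that range land outside [-1, 1]."""
    return 2.0 * (x - q01) / (q99 - q01 + eps) - 1.0


def quantile_unnormalize(y: float, q01: float, q99: float, eps: float = 1e-8) -> float:
    """Inverse of quantile_normalize (same q01, q99 and eps)."""
    return (y + 1.0) / 2.0 * (q99 - q01 + eps) + q01
```

A test fixture that only provides mean/std (or min/max) statistics cannot drive this mapping, which is why the PI0.5 tests need `q01` and `q99` entries in their dataset stats.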

* fixup! Fix PI0.5 RTC tests to use quantile stats (q01, q99) for normalization

* Fixup eval with real robot

* fixup! Fixup eval with real robot

* fixup! fixup! Fixup eval with real robot

* Extract simulator logic from eval_with real robot and add proper headers to files

* Update images

* Fix tests

* fixup! Fix tests

* add docs for rtc

* enhance doc and add images

* Fix install instructions

---------
Co-authored-by: Ben Zhang <benzhangniu@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-11-19 11:19:48 +01:00


#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for RTC debug tracker module."""

import pytest
import torch

from lerobot.policies.rtc.debug_tracker import DebugStep, Tracker

# ====================== Fixtures ======================


@pytest.fixture
def sample_tensors():
    """Create sample tensors for testing."""
    return {
        "x_t": torch.randn(1, 50, 6),
        "v_t": torch.randn(1, 50, 6),
        "x1_t": torch.randn(1, 50, 6),
        "correction": torch.randn(1, 50, 6),
        "err": torch.randn(1, 50, 6),
        "weights": torch.randn(1, 50, 1),
    }


@pytest.fixture
def enabled_tracker():
    """Create an enabled tracker with default settings."""
    return Tracker(enabled=True, maxlen=100)


@pytest.fixture
def disabled_tracker():
    """Create a disabled tracker."""
    return Tracker(enabled=False)


# ====================== DebugStep Tests ======================


def test_debug_step_initialization():
    """Test that DebugStep can be initialized with default values."""
    step = DebugStep()

    assert step.step_idx == 0
    assert step.x_t is None
    assert step.v_t is None
    assert step.x1_t is None
    assert step.correction is None
    assert step.err is None
    assert step.weights is None
    assert step.guidance_weight is None
    assert step.time is None
    assert step.inference_delay is None
    assert step.execution_horizon is None
    assert step.metadata == {}


def test_debug_step_with_values(sample_tensors):
    """Test DebugStep initialization with actual values."""
    step = DebugStep(
        step_idx=5,
        x_t=sample_tensors["x_t"],
        v_t=sample_tensors["v_t"],
        x1_t=sample_tensors["x1_t"],
        correction=sample_tensors["correction"],
        err=sample_tensors["err"],
        weights=sample_tensors["weights"],
        guidance_weight=2.5,
        time=0.8,
        inference_delay=4,
        execution_horizon=8,
        metadata={"custom_key": "custom_value"},
    )

    assert step.step_idx == 5
    assert torch.equal(step.x_t, sample_tensors["x_t"])
    assert torch.equal(step.v_t, sample_tensors["v_t"])
    assert torch.equal(step.x1_t, sample_tensors["x1_t"])
    assert torch.equal(step.correction, sample_tensors["correction"])
    assert torch.equal(step.err, sample_tensors["err"])
    assert torch.equal(step.weights, sample_tensors["weights"])
    assert step.guidance_weight == 2.5
    assert step.time == 0.8
    assert step.inference_delay == 4
    assert step.execution_horizon == 8
    assert step.metadata == {"custom_key": "custom_value"}


def test_debug_step_to_dict_without_tensors(sample_tensors):
    """Test converting DebugStep to dictionary without tensor values."""
    step = DebugStep(
        step_idx=3,
        x_t=sample_tensors["x_t"],
        v_t=sample_tensors["v_t"],
        guidance_weight=torch.tensor(3.0),
        time=torch.tensor(0.5),
        inference_delay=2,
        execution_horizon=10,
    )

    result = step.to_dict(include_tensors=False)

    assert result["step_idx"] == 3
    assert result["guidance_weight"] == 3.0
    assert result["time"] == 0.5
    assert result["inference_delay"] == 2
    assert result["execution_horizon"] == 10

    # Check tensor statistics are included
    assert "x_t_stats" in result
    assert "v_t_stats" in result
    assert "x1_t_stats" not in result  # x1_t was None

    # Verify statistics structure
    assert "shape" in result["x_t_stats"]
    assert "mean" in result["x_t_stats"]
    assert "std" in result["x_t_stats"]
    assert "min" in result["x_t_stats"]
    assert "max" in result["x_t_stats"]

    # Verify shape matches original tensor
    assert result["x_t_stats"]["shape"] == tuple(sample_tensors["x_t"].shape)


def test_debug_step_to_dict_with_tensors(sample_tensors):
    """Test converting DebugStep to dictionary with tensor values."""
    step = DebugStep(
        step_idx=1,
        x_t=sample_tensors["x_t"],
        v_t=sample_tensors["v_t"],
        guidance_weight=1.5,
        time=0.9,
    )

    result = step.to_dict(include_tensors=True)

    assert result["step_idx"] == 1
    assert result["guidance_weight"] == 1.5
    assert result["time"] == 0.9

    # Check tensors are included (as CPU tensors)
    assert "x_t" in result
    assert "v_t" in result
    assert isinstance(result["x_t"], torch.Tensor)
    assert isinstance(result["v_t"], torch.Tensor)
    assert result["x_t"].device.type == "cpu"
    assert result["v_t"].device.type == "cpu"


def test_debug_step_to_dict_with_none_guidance_weight():
    """Test to_dict handles None guidance_weight correctly."""
    step = DebugStep(step_idx=0, time=1.0, guidance_weight=None)

    result = step.to_dict(include_tensors=False)

    assert result["guidance_weight"] is None


# ====================== Tracker Initialization and reset() Tests ======================


def test_tracker_initialization_enabled():
    """Test tracker initialization when enabled."""
    tracker = Tracker(enabled=True, maxlen=50)

    assert tracker.enabled is True
    assert tracker._steps == {}
    assert tracker._maxlen == 50
    assert tracker._step_counter == 0
    assert len(tracker) == 0


def test_tracker_reset_when_enabled(enabled_tracker, sample_tensors):
    """Test reset clears all steps when tracker is enabled."""
    # Add some steps
    enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
    enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 2

    # Reset
    enabled_tracker.reset()

    assert len(enabled_tracker) == 0
    assert enabled_tracker._step_counter == 0
    assert enabled_tracker._steps == {}


def test_tracker_reset_when_disabled(disabled_tracker):
    """Test reset on disabled tracker doesn't cause errors."""
    disabled_tracker.reset()
    assert len(disabled_tracker) == 0


# ====================== Tracker.track() Tests ======================


def test_track_creates_new_step(enabled_tracker, sample_tensors):
    """Test that track creates a new step when time doesn't exist."""
    enabled_tracker.track(
        time=1.0,
        x_t=sample_tensors["x_t"],
        v_t=sample_tensors["v_t"],
        guidance_weight=5.0,
        inference_delay=4,
        execution_horizon=8,
    )

    assert len(enabled_tracker) == 1
    steps = enabled_tracker.get_all_steps()
    assert len(steps) == 1
    assert steps[0].step_idx == 0
    assert steps[0].time == 1.0
    assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
    assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
    assert steps[0].guidance_weight == 5.0
    assert steps[0].inference_delay == 4
    assert steps[0].execution_horizon == 8


def test_track_updates_existing_step(enabled_tracker, sample_tensors):
    """Test that track updates an existing step at the same time."""
    # Create initial step
    enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 1
    steps = enabled_tracker.get_all_steps()
    assert steps[0].v_t is None

    # Update the same timestep with v_t
    enabled_tracker.track(time=0.9, v_t=sample_tensors["v_t"])

    assert len(enabled_tracker) == 1  # Still only one step
    steps = enabled_tracker.get_all_steps()
    assert torch.equal(steps[0].x_t, sample_tensors["x_t"])  # Original x_t preserved
    assert torch.equal(steps[0].v_t, sample_tensors["v_t"])  # New v_t added


def test_track_with_tensor_time(enabled_tracker, sample_tensors):
    """Test track handles tensor time values correctly."""
    time_tensor = torch.tensor(0.8)
    enabled_tracker.track(time=time_tensor, x_t=sample_tensors["x_t"])

    steps = enabled_tracker.get_all_steps()
    assert len(steps) == 1
    assert abs(steps[0].time - 0.8) < 1e-6  # Use approximate comparison for floating point


def test_track_time_rounding(enabled_tracker, sample_tensors):
    """Test that track rounds time to avoid floating point precision issues."""
    # These times should be treated as the same after rounding to 6 decimals
    enabled_tracker.track(time=0.9000001, x_t=sample_tensors["x_t"])
    enabled_tracker.track(time=0.9000002, v_t=sample_tensors["v_t"])

    # Should still be one step (times rounded to same value)
    assert len(enabled_tracker) == 1
    steps = enabled_tracker.get_all_steps()
    assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
    assert torch.equal(steps[0].v_t, sample_tensors["v_t"])


def test_track_does_nothing_when_disabled(disabled_tracker, sample_tensors):
    """Test that track does nothing when tracker is disabled."""
    disabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
    assert len(disabled_tracker) == 0


def test_track_with_metadata(enabled_tracker, sample_tensors):
    """Test track stores custom metadata."""
    enabled_tracker.track(time=0.7, x_t=sample_tensors["x_t"], custom_field="custom_value", count=42)

    steps = enabled_tracker.get_all_steps()
    assert steps[0].metadata["custom_field"] == "custom_value"
    assert steps[0].metadata["count"] == 42


def test_track_updates_metadata(enabled_tracker):
    """Test that track updates metadata for existing steps."""
    enabled_tracker.track(time=0.6, meta1="value1")
    enabled_tracker.track(time=0.6, meta2="value2")

    steps = enabled_tracker.get_all_steps()
    assert steps[0].metadata["meta1"] == "value1"
    assert steps[0].metadata["meta2"] == "value2"


def test_track_clones_tensors(enabled_tracker, sample_tensors):
    """Test that track clones tensors instead of storing references."""
    x_t_original = sample_tensors["x_t"].clone()
    enabled_tracker.track(time=0.5, x_t=sample_tensors["x_t"])

    # Modify original tensor
    sample_tensors["x_t"].fill_(999.0)

    # Tracked tensor should not be affected
    steps = enabled_tracker.get_all_steps()
    assert not torch.equal(steps[0].x_t, sample_tensors["x_t"])
    assert torch.equal(steps[0].x_t, x_t_original)


def test_track_with_none_values(enabled_tracker):
    """Test track handles None values correctly."""
    enabled_tracker.track(
        time=0.4,
        x_t=None,
        v_t=None,
        guidance_weight=None,
        inference_delay=None,
    )

    steps = enabled_tracker.get_all_steps()
    assert len(steps) == 1
    assert steps[0].x_t is None
    assert steps[0].v_t is None
    assert steps[0].guidance_weight is None
    assert steps[0].inference_delay is None


def test_track_updates_only_non_none_fields(enabled_tracker, sample_tensors):
    """Test that update preserves existing values when None is passed."""
    # Create step with x_t
    enabled_tracker.track(time=0.3, x_t=sample_tensors["x_t"], guidance_weight=2.0)

    # Update with v_t only (pass None for other fields)
    enabled_tracker.track(time=0.3, v_t=sample_tensors["v_t"], x_t=None, guidance_weight=None)

    # Original values should be preserved
    steps = enabled_tracker.get_all_steps()
    assert torch.equal(steps[0].x_t, sample_tensors["x_t"])  # Still has x_t
    assert torch.equal(steps[0].v_t, sample_tensors["v_t"])  # Now has v_t
    assert steps[0].guidance_weight == 2.0  # Still has guidance_weight


# ====================== Tracker.maxlen Tests ======================


def test_tracker_enforces_maxlen():
    """Test that tracker enforces maxlen limit."""
    tracker = Tracker(enabled=True, maxlen=3)

    # Add 5 steps
    for i in range(5):
        time = 1.0 - i * 0.1  # 1.0, 0.9, 0.8, 0.7, 0.6
        tracker.track(time=time, x_t=torch.randn(1, 10, 6))

    # Should only keep the last 3
    assert len(tracker) == 3

    # Verify oldest steps were removed (should have 0.6, 0.7, 0.8)
    steps = tracker.get_all_steps()
    times = sorted([step.time for step in steps])
    assert times == [0.6, 0.7, 0.8]


def test_tracker_step_idx_increments_despite_maxlen():
    """Test that step_idx continues incrementing even when maxlen is enforced."""
    tracker = Tracker(enabled=True, maxlen=2)

    # Add 4 steps
    for i in range(4):
        time = 1.0 - i * 0.1
        tracker.track(time=time, x_t=torch.randn(1, 10, 6))

    # Should have 2 steps with step_idx 2 and 3 (oldest removed)
    steps = sorted(tracker.get_all_steps(), key=lambda s: s.step_idx)
    assert len(steps) == 2
    assert steps[0].step_idx == 2
    assert steps[1].step_idx == 3


def test_tracker_without_maxlen_keeps_all():
    """Test that tracker without maxlen keeps all steps."""
    tracker = Tracker(enabled=True, maxlen=None)

    # Add 100 steps
    for i in range(100):
        time = 1.0 - i * 0.01
        tracker.track(time=time, x_t=torch.randn(1, 10, 6))

    assert len(tracker) == 100


# ====================== Tracker.get_all_steps() Tests ======================


def test_get_all_steps_returns_empty_when_disabled(disabled_tracker):
    """Test get_all_steps returns empty list when disabled."""
    steps = disabled_tracker.get_all_steps()
    assert steps == []
    assert isinstance(steps, list)


def test_get_all_steps_returns_empty_when_no_steps(enabled_tracker):
    """Test get_all_steps returns empty list when no steps tracked."""
    steps = enabled_tracker.get_all_steps()
    assert steps == []


def test_get_all_steps_returns_all_tracked_steps(enabled_tracker, sample_tensors):
    """Test get_all_steps returns all tracked steps."""
    # Track 5 steps
    for i in range(5):
        time = 1.0 - i * 0.1
        enabled_tracker.track(time=time, x_t=sample_tensors["x_t"])

    steps = enabled_tracker.get_all_steps()
    assert len(steps) == 5

    # Verify all are DebugStep instances
    for step in steps:
        assert isinstance(step, DebugStep)


def test_get_all_steps_preserves_insertion_order(enabled_tracker):
    """Test that get_all_steps preserves insertion order (Python 3.7+)."""
    times = [0.9, 0.8, 0.7, 0.6, 0.5]
    for time in times:
        enabled_tracker.track(time=time, x_t=torch.randn(1, 10, 6))

    steps = enabled_tracker.get_all_steps()
    retrieved_times = [step.time for step in steps]

    # Should be in insertion order
    assert retrieved_times == times


# ====================== Tracker.__len__() Tests ======================


def test_len_returns_zero_when_disabled(disabled_tracker):
    """Test __len__ returns 0 when tracker is disabled."""
    assert len(disabled_tracker) == 0


def test_len_returns_zero_when_empty(enabled_tracker):
    """Test __len__ returns 0 when no steps are tracked."""
    assert len(enabled_tracker) == 0


def test_len_returns_correct_count(enabled_tracker, sample_tensors):
    """Test __len__ returns correct number of tracked steps."""
    assert len(enabled_tracker) == 0

    enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 1

    enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 2

    enabled_tracker.track(time=0.8, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 3


def test_len_after_reset(enabled_tracker, sample_tensors):
    """Test __len__ returns 0 after reset."""
    enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
    enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
    assert len(enabled_tracker) == 2

    enabled_tracker.reset()
    assert len(enabled_tracker) == 0


# ====================== Edge Case Tests ======================


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tracker_handles_gpu_tensors():
    """Test tracker correctly handles GPU tensors."""
    tracker = Tracker(enabled=True, maxlen=10)
    x_t_gpu = torch.randn(1, 50, 6, device="cuda")

    tracker.track(time=1.0, x_t=x_t_gpu)

    steps = tracker.get_all_steps()
    # Tracker should clone and detach tensors while keeping them on their original (CUDA) device
    assert steps[0].x_t.device.type == "cuda"


def test_tracker_with_varying_tensor_shapes(enabled_tracker):
    """Test tracker handles varying tensor shapes across steps."""
    enabled_tracker.track(time=1.0, x_t=torch.randn(1, 50, 6))
    enabled_tracker.track(time=0.9, x_t=torch.randn(1, 25, 6))
    enabled_tracker.track(time=0.8, x_t=torch.randn(2, 50, 8))

    steps = enabled_tracker.get_all_steps()
    assert len(steps) == 3
    assert steps[0].x_t.shape == (1, 50, 6)
    assert steps[1].x_t.shape == (1, 25, 6)
    assert steps[2].x_t.shape == (2, 50, 8)