mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
Fix traacking
This commit is contained in:
@@ -68,7 +68,7 @@ class RTCEvalConfig(HubMixin):
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=5.0,
|
||||
max_guidance_weight=10.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=True,
|
||||
debug_maxlen=1000,
|
||||
@@ -184,6 +184,10 @@ class RTCEvaluator:
|
||||
preprocessed_first_sample,
|
||||
)[:, :25, :].squeeze(0)
|
||||
|
||||
self.policy.rtc_processor.reset_tracker()
|
||||
|
||||
logging.info("Resetting tracker")
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
||||
noise = self.policy.model.sample_noise(noise_size, self.device)
|
||||
@@ -300,6 +304,9 @@ class RTCEvaluator:
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info(f"Plotting {len(tracked_steps)} steps")
|
||||
|
||||
debug_steps = tracked_steps
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
@@ -40,7 +40,7 @@ class RTCConfig:
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
|
||||
max_guidance_weight: float = 5.0
|
||||
max_guidance_weight: float = 10.0
|
||||
execution_horizon: int = 10
|
||||
|
||||
# Debug settings
|
||||
|
||||
@@ -260,7 +260,6 @@ class RTCProcessor:
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
# Record debug information (all params except x_t which is recorded externally)
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
|
||||
@@ -798,12 +798,13 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
Reference in New Issue
Block a user