fixup! Add matplotliv to dev

This commit is contained in:
Eugene Mironov
2025-11-07 01:33:03 +07:00
parent a42fb4d0e2
commit 68b2142bd2
2 changed files with 43 additions and 39 deletions
+38 -39
View File
@@ -27,7 +27,7 @@ Usage:
--use_torch_compile=true \ --use_torch_compile=true \
--torch_compile_mode=max-autotune --torch_compile_mode=max-autotune
# With torch.compile for faster inference (PyTorch 2.0+) # With torch.compile on CUDA
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
--policy.path=helper2424/smolvla_check_rtc_last3 \ --policy.path=helper2424/smolvla_check_rtc_last3 \
--dataset.repo_id=helper2424/check_rtc \ --dataset.repo_id=helper2424/check_rtc \
@@ -202,31 +202,6 @@ class RTCEvaluator:
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory") logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
logging.info("=" * 80) logging.info("=" * 80)
# Policy 1: For generating previous chunk (RTC disabled, no debug)
self.policy_prev_chunk = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
self.policy_no_rtc = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Policy 3: For RTC inference (RTC enabled, debug enabled)
self.policy_rtc = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
logging.info("=" * 80)
logging.info("All policies initialized successfully")
logging.info("=" * 80)
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool): def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
"""Initialize a single policy instance with specified RTC configuration. """Initialize a single policy instance with specified RTC configuration.
@@ -290,8 +265,11 @@ class RTCEvaluator:
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...") logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
logging.info(f" Backend: {self.cfg.torch_compile_backend}") logging.info(f" Backend: {self.cfg.torch_compile_backend}")
logging.info(f" Mode: {self.cfg.torch_compile_mode}") logging.info(f" Mode: {self.cfg.torch_compile_mode}")
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
# Compile the predict_action_chunk method # Compile the predict_action_chunk method
# The debug tracker is excluded from compilation via @torch._dynamo.disable decorator
# on the Tracker.track() method, so it won't cause graph breaks
original_method = policy.predict_action_chunk original_method = policy.predict_action_chunk
compiled_method = torch.compile( compiled_method = torch.compile(
original_method, original_method,
@@ -368,56 +346,77 @@ class RTCEvaluator:
preprocessed_first_sample = self.preprocessor(first_sample) preprocessed_first_sample = self.preprocessor(first_sample)
preprocessed_second_sample = self.preprocessor(second_sample) preprocessed_second_sample = self.preprocessor(second_sample)
# Policy 1: For generating previous chunk (RTC disabled, no debug)
policy_prev_chunk_policy = self._init_policy(
name="policy_prev_chunk",
rtc_enabled=False,
rtc_debug=False,
)
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
policy_no_rtc_policy = self._init_policy(
name="policy_no_rtc",
rtc_enabled=False,
rtc_debug=True,
)
# Policy 3: For RTC inference (RTC enabled, debug enabled)
policy_rtc_policy = self._init_policy(
name="policy_rtc",
rtc_enabled=True,
rtc_debug=True,
)
# Step 1: Generate previous chunk using policy_prev_chunk # Step 1: Generate previous chunk using policy_prev_chunk
# This policy is only used to generate the reference chunk and then freed # This policy is only used to generate the reference chunk and then freed
logging.info("Step 1: Generating previous chunk with policy_prev_chunk") logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
with torch.no_grad(): with torch.no_grad():
prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk( prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
preprocessed_first_sample, preprocessed_first_sample,
)[:, :25, :].squeeze(0) )[:, :25, :].squeeze(0)
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}") logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
# Destroy policy_prev_chunk to free memory for large models # Destroy policy_prev_chunk to free memory for large models
self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk") self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
# Sample noise (use same noise for both RTC and non-RTC for fair comparison) # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim) noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device) noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
noise_clone = noise.clone() noise_clone = noise.clone()
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc") logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
self.policy_no_rtc.rtc_processor.reset_tracker() policy_no_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad(): with torch.no_grad():
_ = self.policy_no_rtc.predict_action_chunk( _ = policy_no_rtc_policy.predict_action_chunk(
preprocessed_second_sample, preprocessed_second_sample,
noise=noise, noise=noise,
) )
no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps() no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC") logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
# Destroy policy_no_rtc to free memory before loading policy_rtc # Destroy policy_no_rtc to free memory before loading policy_rtc
self._destroy_policy(self.policy_no_rtc, "policy_no_rtc") self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
# Step 3: Generate actions WITH RTC using policy_rtc # Step 3: Generate actions WITH RTC using policy_rtc
logging.info("Step 3: Generating actions WITH RTC with policy_rtc") logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
self.policy_rtc.rtc_processor.reset_tracker() policy_rtc_policy.rtc_processor.reset_tracker()
with torch.no_grad(): with torch.no_grad():
_ = self.policy_rtc.predict_action_chunk( _ = policy_rtc_policy.predict_action_chunk(
preprocessed_second_sample, preprocessed_second_sample,
noise=noise_clone, noise=noise_clone,
inference_delay=self.cfg.inference_delay, inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over, prev_chunk_left_over=prev_chunk_left_over,
execution_horizon=self.cfg.rtc.execution_horizon, execution_horizon=self.cfg.rtc.execution_horizon,
) )
rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps() rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC") logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
# Save num_steps before destroying policy (needed for plotting) # Save num_steps before destroying policy (needed for plotting)
num_steps = self.policy_rtc.config.num_steps num_steps = policy_rtc_policy.config.num_steps
# Destroy policy_rtc after final use # Destroy policy_rtc after final use
self._destroy_policy(self.policy_rtc, "policy_rtc") self._destroy_policy(policy_rtc_policy, "policy_rtc")
# Plot and save results # Plot and save results
logging.info("=" * 80) logging.info("=" * 80)
@@ -19,6 +19,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
import torch
from torch import Tensor from torch import Tensor
@@ -120,6 +121,7 @@ class Tracker:
self._steps.clear() self._steps.clear()
self._step_counter = 0 self._step_counter = 0
@torch._dynamo.disable
def track( def track(
self, self,
time: float | Tensor, time: float | Tensor,
@@ -139,6 +141,9 @@ class Tracker:
If a step with the given time already exists, it will be updated with the new data. If a step with the given time already exists, it will be updated with the new data.
Otherwise, a new step will be created. Only non-None fields are updated/set. Otherwise, a new step will be created. Only non-None fields are updated/set.
Note: This method is excluded from torch.compile to avoid graph breaks from
operations like .item() which are incompatible with compiled graphs.
Args: Args:
time (float | Tensor): Time parameter - used as the key to identify the step. time (float | Tensor): Time parameter - used as the key to identify the step.
x_t (Tensor | None): Current latent/state tensor. x_t (Tensor | None): Current latent/state tensor.