Use output_dir for saving all evaluation images

Update eval_dataset.py to save all comparison images to the
configured output_dir instead of the current directory. This provides
better organization and allows users to specify where outputs should be
saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Eugene Mironov
2025-11-03 18:55:12 +07:00
parent d20ef2e46e
commit c409ed2d1d
2 changed files with 33 additions and 51 deletions
+31 -45
View File
@@ -19,6 +19,7 @@ Usage:
""" """
import logging import logging
import os
import random import random
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -91,14 +92,6 @@ class RTCEvalConfig(HubMixin):
default="rtc_debug_output", default="rtc_debug_output",
metadata={"help": "Directory to save debug visualizations"}, metadata={"help": "Directory to save debug visualizations"},
) )
verbose: bool = field(
default=False,
metadata={"help": "Enable verbose logging"},
)
enable_debug_viz: bool = field(
default=True,
metadata={"help": "Enable debug visualization"},
)
# Seed configuration # Seed configuration
seed: int = field( seed: int = field(
@@ -154,7 +147,7 @@ class RTCEvaluator:
# Configure RTC # Configure RTC
cfg.rtc.enabled = True cfg.rtc.enabled = True
self.policy.config.rtc_config = cfg.rtc self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor(verbose=cfg.verbose) self.policy.init_rtc_processor()
logger.info(f"Policy loaded: {self.policy.name}") logger.info(f"Policy loaded: {self.policy.name}")
logger.info(f"RTC enabled: {cfg.rtc.enabled}") logger.info(f"RTC enabled: {cfg.rtc.enabled}")
@@ -176,31 +169,25 @@ class RTCEvaluator:
def run_evaluation(self): def run_evaluation(self):
"""Run evaluation on two random dataset samples.""" """Run evaluation on two random dataset samples."""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logger.info(f"Output directory: {self.cfg.output_dir}")
logger.info("Starting RTC evaluation") logger.info("Starting RTC evaluation")
logger.info(f"Inference delay: {self.cfg.inference_delay}") logger.info(f"Inference delay: {self.cfg.inference_delay}")
# Get two random samples from the dataset data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
idx1, idx2 = random.sample(range(len(self.dataset)), 2) loader_iter = iter(data_loader)
logger.info(f"Selected samples: {idx1}, {idx2}") first_sample = next(loader_iter)
second_sample = next(loader_iter)
# Get first sample - use its actions as prev_chunk preprocessed_first_sample = self.preprocessor(first_sample)
sample1 = self.dataset[idx1] preprocessed_second_sample = self.preprocessor(second_sample)
for key, value in sample1.items():
if isinstance(value, torch.Tensor):
sample1[key] = value.unsqueeze(0).to(self.device)
preprocessed_sample1 = self.preprocessor(sample1) # Don't postprocess the previous chunk
prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25] prev_chunk_left_over = self.policy.predict_action_chunk(
logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}") preprocessed_first_sample,
)[:, :25, :].squeeze(0)
# Get second sample - generate actions for this one
sample2 = self.dataset[idx2]
for key, value in sample2.items():
if isinstance(value, torch.Tensor):
sample2[key] = value.unsqueeze(0).to(self.device)
preprocessed_sample2 = self.preprocessor(sample2)
logger.info(f"Generating actions for sample {idx2}")
# Sample noise (use same noise for both RTC and non-RTC for fair comparison) # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim) noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
@@ -222,10 +209,8 @@ class RTCEvaluator:
self.policy.config.rtc_config.enabled = False self.policy.config.rtc_config.enabled = False
with torch.no_grad(): with torch.no_grad():
no_rtc_actions = self.policy.predict_action_chunk( no_rtc_actions = self.policy.predict_action_chunk(
preprocessed_sample2, preprocessed_second_sample,
noise=noise, noise=noise,
inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over,
viz_xt_axs=axs_xt[:, 0], # Left column for x_t viz_xt_axs=axs_xt[:, 0], # Left column for x_t
viz_vt_axs=axs_vt[:, 0], # Left column for v_t viz_vt_axs=axs_vt[:, 0], # Left column for v_t
) )
@@ -235,7 +220,7 @@ class RTCEvaluator:
self.policy.config.rtc_config.enabled = True self.policy.config.rtc_config.enabled = True
with torch.no_grad(): with torch.no_grad():
rtc_actions = self.policy.predict_action_chunk( rtc_actions = self.policy.predict_action_chunk(
preprocessed_sample2, preprocessed_second_sample,
noise=noise_clone, noise=noise_clone,
inference_delay=self.cfg.inference_delay, inference_delay=self.cfg.inference_delay,
prev_chunk_left_over=prev_chunk_left_over, prev_chunk_left_over=prev_chunk_left_over,
@@ -263,18 +248,21 @@ class RTCEvaluator:
# Save denoising plots # Save denoising plots
fig_xt.tight_layout() fig_xt.tight_layout()
fig_xt.savefig("denoising_xt_comparison.png", dpi=150) xt_path = os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")
logger.info("Saved x_t denoising comparison to denoising_xt_comparison.png") fig_xt.savefig(xt_path, dpi=150)
logger.info(f"Saved x_t denoising comparison to {xt_path}")
plt.close(fig_xt) plt.close(fig_xt)
fig_vt.tight_layout() fig_vt.tight_layout()
fig_vt.savefig("denoising_vt_comparison.png", dpi=150) vt_path = os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")
logger.info("Saved v_t denoising comparison to denoising_vt_comparison.png") fig_vt.savefig(vt_path, dpi=150)
logger.info(f"Saved v_t denoising comparison to {vt_path}")
plt.close(fig_vt) plt.close(fig_vt)
fig_x1t.tight_layout() fig_x1t.tight_layout()
fig_x1t.savefig("denoising_x1t_comparison.png", dpi=150) x1t_path = os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")
logger.info("Saved x1_t predicted state & error comparison to denoising_x1t_comparison.png") fig_x1t.savefig(x1t_path, dpi=150)
logger.info(f"Saved x1_t predicted state & error comparison to {x1t_path}")
plt.close(fig_x1t) plt.close(fig_x1t)
# Create side-by-side comparison: No RTC (left) vs RTC (right) # Create side-by-side comparison: No RTC (left) vs RTC (right)
@@ -298,13 +286,13 @@ class RTCEvaluator:
) )
plt.tight_layout() plt.tight_layout()
plt.savefig("final_actions_comparison.png", dpi=150) final_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
logger.info("Saved final actions comparison to final_actions_comparison.png") plt.savefig(final_path, dpi=150)
logger.info(f"Saved final actions comparison to {final_path}")
plt.close(fig) plt.close(fig)
# Visualize debug information if enabled # Visualize debug information if enabled
if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None: self._visualize_debug_info()
self._visualize_debug_info()
logger.info("Evaluation completed successfully") logger.info("Evaluation completed successfully")
@@ -338,8 +326,6 @@ class RTCEvaluator:
def _visualize_debug_info(self): def _visualize_debug_info(self):
"""Visualize debug information from the RTC processor.""" """Visualize debug information from the RTC processor."""
import os
# Use proxy method to check if debug is enabled # Use proxy method to check if debug is enabled
if not self.policy.rtc_processor.is_debug_enabled(): if not self.policy.rtc_processor.is_debug_enabled():
logger.warning("Debug tracking is disabled. Skipping debug visualization.") logger.warning("Debug tracking is disabled. Skipping debug visualization.")
@@ -249,12 +249,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps), ACTION: deque(maxlen=self.config.n_action_steps),
} }
def init_rtc_processor(self, verbose: bool = False): def init_rtc_processor(self):
"""Initialize RTC processor with optional verbose logging. """Initialize RTC processor if RTC is enabled in config."""
Args:
verbose: Enable verbose debug logging in RTCProcessor (currently unused)
"""
self.rtc_processor = None self.rtc_processor = None
if self.config.rtc_config is not None and self.config.rtc_config.enabled: if self.config.rtc_config is not None and self.config.rtc_config.enabled: