mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
Use output_dir for saving all evaluation images
Update eval_dataset.py to save all comparison images to the configured output_dir instead of the current directory. This provides better organization and allows users to specify where outputs should be saved. Changes: - Add os import at top level - Create output_dir at start of run_evaluation() - Save all comparison images to output_dir - Remove duplicate os imports - Update init_rtc_processor() docstring to be more concise 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com> Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,7 @@ Usage:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -91,14 +92,6 @@ class RTCEvalConfig(HubMixin):
|
||||
default="rtc_debug_output",
|
||||
metadata={"help": "Directory to save debug visualizations"},
|
||||
)
|
||||
verbose: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable verbose logging"},
|
||||
)
|
||||
enable_debug_viz: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Enable debug visualization"},
|
||||
)
|
||||
|
||||
# Seed configuration
|
||||
seed: int = field(
|
||||
@@ -154,7 +147,7 @@ class RTCEvaluator:
|
||||
# Configure RTC
|
||||
cfg.rtc.enabled = True
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor(verbose=cfg.verbose)
|
||||
self.policy.init_rtc_processor()
|
||||
|
||||
logger.info(f"Policy loaded: {self.policy.name}")
|
||||
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||
@@ -176,31 +169,25 @@ class RTCEvaluator:
|
||||
|
||||
def run_evaluation(self):
|
||||
"""Run evaluation on two random dataset samples."""
|
||||
# Create output directory
|
||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||
logger.info(f"Output directory: {self.cfg.output_dir}")
|
||||
|
||||
logger.info("Starting RTC evaluation")
|
||||
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||
|
||||
# Get two random samples from the dataset
|
||||
idx1, idx2 = random.sample(range(len(self.dataset)), 2)
|
||||
logger.info(f"Selected samples: {idx1}, {idx2}")
|
||||
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||
loader_iter = iter(data_loader)
|
||||
first_sample = next(loader_iter)
|
||||
second_sample = next(loader_iter)
|
||||
|
||||
# Get first sample - use its actions as prev_chunk
|
||||
sample1 = self.dataset[idx1]
|
||||
for key, value in sample1.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
sample1[key] = value.unsqueeze(0).to(self.device)
|
||||
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||
|
||||
preprocessed_sample1 = self.preprocessor(sample1)
|
||||
prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25]
|
||||
logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}")
|
||||
|
||||
# Get second sample - generate actions for this one
|
||||
sample2 = self.dataset[idx2]
|
||||
for key, value in sample2.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
sample2[key] = value.unsqueeze(0).to(self.device)
|
||||
|
||||
preprocessed_sample2 = self.preprocessor(sample2)
|
||||
logger.info(f"Generating actions for sample {idx2}")
|
||||
# Don't postprocess the previous chunk
|
||||
prev_chunk_left_over = self.policy.predict_action_chunk(
|
||||
preprocessed_first_sample,
|
||||
)[:, :25, :].squeeze(0)
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
||||
@@ -222,10 +209,8 @@ class RTCEvaluator:
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_sample2,
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
||||
)
|
||||
@@ -235,7 +220,7 @@ class RTCEvaluator:
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
rtc_actions = self.policy.predict_action_chunk(
|
||||
preprocessed_sample2,
|
||||
preprocessed_second_sample,
|
||||
noise=noise_clone,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
@@ -263,18 +248,21 @@ class RTCEvaluator:
|
||||
|
||||
# Save denoising plots
|
||||
fig_xt.tight_layout()
|
||||
fig_xt.savefig("denoising_xt_comparison.png", dpi=150)
|
||||
logger.info("Saved x_t denoising comparison to denoising_xt_comparison.png")
|
||||
xt_path = os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")
|
||||
fig_xt.savefig(xt_path, dpi=150)
|
||||
logger.info(f"Saved x_t denoising comparison to {xt_path}")
|
||||
plt.close(fig_xt)
|
||||
|
||||
fig_vt.tight_layout()
|
||||
fig_vt.savefig("denoising_vt_comparison.png", dpi=150)
|
||||
logger.info("Saved v_t denoising comparison to denoising_vt_comparison.png")
|
||||
vt_path = os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")
|
||||
fig_vt.savefig(vt_path, dpi=150)
|
||||
logger.info(f"Saved v_t denoising comparison to {vt_path}")
|
||||
plt.close(fig_vt)
|
||||
|
||||
fig_x1t.tight_layout()
|
||||
fig_x1t.savefig("denoising_x1t_comparison.png", dpi=150)
|
||||
logger.info("Saved x1_t predicted state & error comparison to denoising_x1t_comparison.png")
|
||||
x1t_path = os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")
|
||||
fig_x1t.savefig(x1t_path, dpi=150)
|
||||
logger.info(f"Saved x1_t predicted state & error comparison to {x1t_path}")
|
||||
plt.close(fig_x1t)
|
||||
|
||||
# Create side-by-side comparison: No RTC (left) vs RTC (right)
|
||||
@@ -298,13 +286,13 @@ class RTCEvaluator:
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig("final_actions_comparison.png", dpi=150)
|
||||
logger.info("Saved final actions comparison to final_actions_comparison.png")
|
||||
final_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||
plt.savefig(final_path, dpi=150)
|
||||
logger.info(f"Saved final actions comparison to {final_path}")
|
||||
plt.close(fig)
|
||||
|
||||
# Visualize debug information if enabled
|
||||
if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None:
|
||||
self._visualize_debug_info()
|
||||
self._visualize_debug_info()
|
||||
|
||||
logger.info("Evaluation completed successfully")
|
||||
|
||||
@@ -338,8 +326,6 @@ class RTCEvaluator:
|
||||
|
||||
def _visualize_debug_info(self):
|
||||
"""Visualize debug information from the RTC processor."""
|
||||
import os
|
||||
|
||||
# Use proxy method to check if debug is enabled
|
||||
if not self.policy.rtc_processor.is_debug_enabled():
|
||||
logger.warning("Debug tracking is disabled. Skipping debug visualization.")
|
||||
|
||||
Reference in New Issue
Block a user