mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Use output_dir for saving all evaluation images
Update eval_dataset.py to save all comparison images to the configured output_dir instead of the current directory. This provides better organization and allows users to specify where outputs should be saved.

Changes:
- Add os import at top level
- Create output_dir at start of run_evaluation()
- Save all comparison images to output_dir
- Remove duplicate os imports
- Update init_rtc_processor() docstring to be more concise

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Alexander Soare <alexander.soare159@gmail.com>
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -19,6 +19,7 @@ Usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
@@ -91,14 +92,6 @@ class RTCEvalConfig(HubMixin):
|
|||||||
default="rtc_debug_output",
|
default="rtc_debug_output",
|
||||||
metadata={"help": "Directory to save debug visualizations"},
|
metadata={"help": "Directory to save debug visualizations"},
|
||||||
)
|
)
|
||||||
verbose: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Enable verbose logging"},
|
|
||||||
)
|
|
||||||
enable_debug_viz: bool = field(
|
|
||||||
default=True,
|
|
||||||
metadata={"help": "Enable debug visualization"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Seed configuration
|
# Seed configuration
|
||||||
seed: int = field(
|
seed: int = field(
|
||||||
@@ -154,7 +147,7 @@ class RTCEvaluator:
|
|||||||
# Configure RTC
|
# Configure RTC
|
||||||
cfg.rtc.enabled = True
|
cfg.rtc.enabled = True
|
||||||
self.policy.config.rtc_config = cfg.rtc
|
self.policy.config.rtc_config = cfg.rtc
|
||||||
self.policy.init_rtc_processor(verbose=cfg.verbose)
|
self.policy.init_rtc_processor()
|
||||||
|
|
||||||
logger.info(f"Policy loaded: {self.policy.name}")
|
logger.info(f"Policy loaded: {self.policy.name}")
|
||||||
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||||
@@ -176,31 +169,25 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
def run_evaluation(self):
|
def run_evaluation(self):
|
||||||
"""Run evaluation on two random dataset samples."""
|
"""Run evaluation on two random dataset samples."""
|
||||||
|
# Create output directory
|
||||||
|
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||||
|
logger.info(f"Output directory: {self.cfg.output_dir}")
|
||||||
|
|
||||||
logger.info("Starting RTC evaluation")
|
logger.info("Starting RTC evaluation")
|
||||||
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||||
|
|
||||||
# Get two random samples from the dataset
|
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||||
idx1, idx2 = random.sample(range(len(self.dataset)), 2)
|
loader_iter = iter(data_loader)
|
||||||
logger.info(f"Selected samples: {idx1}, {idx2}")
|
first_sample = next(loader_iter)
|
||||||
|
second_sample = next(loader_iter)
|
||||||
|
|
||||||
# Get first sample - use its actions as prev_chunk
|
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||||
sample1 = self.dataset[idx1]
|
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||||
for key, value in sample1.items():
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
sample1[key] = value.unsqueeze(0).to(self.device)
|
|
||||||
|
|
||||||
preprocessed_sample1 = self.preprocessor(sample1)
|
# Don't postprocess the previous chunk
|
||||||
prev_chunk_left_over = preprocessed_sample1["action"][0, :, :25]
|
prev_chunk_left_over = self.policy.predict_action_chunk(
|
||||||
logger.info(f"Using actions from sample {idx1} as previous chunk: shape={prev_chunk_left_over.shape}")
|
preprocessed_first_sample,
|
||||||
|
)[:, :25, :].squeeze(0)
|
||||||
# Get second sample - generate actions for this one
|
|
||||||
sample2 = self.dataset[idx2]
|
|
||||||
for key, value in sample2.items():
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
sample2[key] = value.unsqueeze(0).to(self.device)
|
|
||||||
|
|
||||||
preprocessed_sample2 = self.preprocessor(sample2)
|
|
||||||
logger.info(f"Generating actions for sample {idx2}")
|
|
||||||
|
|
||||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||||
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
noise_size = (1, self.policy.config.chunk_size, self.policy.config.max_action_dim)
|
||||||
@@ -222,10 +209,8 @@ class RTCEvaluator:
|
|||||||
self.policy.config.rtc_config.enabled = False
|
self.policy.config.rtc_config.enabled = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
no_rtc_actions = self.policy.predict_action_chunk(
|
no_rtc_actions = self.policy.predict_action_chunk(
|
||||||
preprocessed_sample2,
|
preprocessed_second_sample,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
inference_delay=self.cfg.inference_delay,
|
|
||||||
prev_chunk_left_over=prev_chunk_left_over,
|
|
||||||
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
viz_xt_axs=axs_xt[:, 0], # Left column for x_t
|
||||||
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
viz_vt_axs=axs_vt[:, 0], # Left column for v_t
|
||||||
)
|
)
|
||||||
@@ -235,7 +220,7 @@ class RTCEvaluator:
|
|||||||
self.policy.config.rtc_config.enabled = True
|
self.policy.config.rtc_config.enabled = True
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
rtc_actions = self.policy.predict_action_chunk(
|
rtc_actions = self.policy.predict_action_chunk(
|
||||||
preprocessed_sample2,
|
preprocessed_second_sample,
|
||||||
noise=noise_clone,
|
noise=noise_clone,
|
||||||
inference_delay=self.cfg.inference_delay,
|
inference_delay=self.cfg.inference_delay,
|
||||||
prev_chunk_left_over=prev_chunk_left_over,
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
@@ -263,18 +248,21 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
# Save denoising plots
|
# Save denoising plots
|
||||||
fig_xt.tight_layout()
|
fig_xt.tight_layout()
|
||||||
fig_xt.savefig("denoising_xt_comparison.png", dpi=150)
|
xt_path = os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")
|
||||||
logger.info("Saved x_t denoising comparison to denoising_xt_comparison.png")
|
fig_xt.savefig(xt_path, dpi=150)
|
||||||
|
logger.info(f"Saved x_t denoising comparison to {xt_path}")
|
||||||
plt.close(fig_xt)
|
plt.close(fig_xt)
|
||||||
|
|
||||||
fig_vt.tight_layout()
|
fig_vt.tight_layout()
|
||||||
fig_vt.savefig("denoising_vt_comparison.png", dpi=150)
|
vt_path = os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")
|
||||||
logger.info("Saved v_t denoising comparison to denoising_vt_comparison.png")
|
fig_vt.savefig(vt_path, dpi=150)
|
||||||
|
logger.info(f"Saved v_t denoising comparison to {vt_path}")
|
||||||
plt.close(fig_vt)
|
plt.close(fig_vt)
|
||||||
|
|
||||||
fig_x1t.tight_layout()
|
fig_x1t.tight_layout()
|
||||||
fig_x1t.savefig("denoising_x1t_comparison.png", dpi=150)
|
x1t_path = os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png")
|
||||||
logger.info("Saved x1_t predicted state & error comparison to denoising_x1t_comparison.png")
|
fig_x1t.savefig(x1t_path, dpi=150)
|
||||||
|
logger.info(f"Saved x1_t predicted state & error comparison to {x1t_path}")
|
||||||
plt.close(fig_x1t)
|
plt.close(fig_x1t)
|
||||||
|
|
||||||
# Create side-by-side comparison: No RTC (left) vs RTC (right)
|
# Create side-by-side comparison: No RTC (left) vs RTC (right)
|
||||||
@@ -298,13 +286,13 @@ class RTCEvaluator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig("final_actions_comparison.png", dpi=150)
|
final_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||||
logger.info("Saved final actions comparison to final_actions_comparison.png")
|
plt.savefig(final_path, dpi=150)
|
||||||
|
logger.info(f"Saved final actions comparison to {final_path}")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
# Visualize debug information if enabled
|
# Visualize debug information if enabled
|
||||||
if self.cfg.enable_debug_viz and self.policy.rtc_processor is not None:
|
self._visualize_debug_info()
|
||||||
self._visualize_debug_info()
|
|
||||||
|
|
||||||
logger.info("Evaluation completed successfully")
|
logger.info("Evaluation completed successfully")
|
||||||
|
|
||||||
@@ -338,8 +326,6 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
def _visualize_debug_info(self):
|
def _visualize_debug_info(self):
|
||||||
"""Visualize debug information from the RTC processor."""
|
"""Visualize debug information from the RTC processor."""
|
||||||
import os
|
|
||||||
|
|
||||||
# Use proxy method to check if debug is enabled
|
# Use proxy method to check if debug is enabled
|
||||||
if not self.policy.rtc_processor.is_debug_enabled():
|
if not self.policy.rtc_processor.is_debug_enabled():
|
||||||
logger.warning("Debug tracking is disabled. Skipping debug visualization.")
|
logger.warning("Debug tracking is disabled. Skipping debug visualization.")
|
||||||
|
|||||||
@@ -249,12 +249,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
def init_rtc_processor(self, verbose: bool = False):
|
def init_rtc_processor(self):
|
||||||
"""Initialize RTC processor with optional verbose logging.
|
"""Initialize RTC processor if RTC is enabled in config."""
|
||||||
|
|
||||||
Args:
|
|
||||||
verbose: Enable verbose debug logging in RTCProcessor (currently unused)
|
|
||||||
"""
|
|
||||||
self.rtc_processor = None
|
self.rtc_processor = None
|
||||||
|
|
||||||
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
if self.config.rtc_config is not None and self.config.rtc_config.enabled:
|
||||||
|
|||||||
Reference in New Issue
Block a user