fixup! Refactor plotting logging

Eugene Mironov
2025-11-04 02:11:54 +07:00
parent 84df6cd13d
commit aaa308b158
+30 -37
@@ -21,7 +21,6 @@ Usage:
import logging
import os
import random
-import sys
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
@@ -37,17 +36,7 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-    force=True,
-)
-logger = logging.getLogger(__name__)
-# Ensure logs are flushed immediately
-sys.stdout.reconfigure(line_buffering=True)
-sys.stderr.reconfigure(line_buffering=True)
+from lerobot.utils.utils import init_logging
def set_seed(seed: int):
@@ -62,7 +51,6 @@ def set_seed(seed: int):
torch.mps.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
logger.info(f"Random seed set to: {seed}")
@dataclass
@@ -128,7 +116,7 @@ class RTCEvalConfig(HubMixin):
self.device = "mps"
else:
self.device = "cpu"
logger.info(f"Auto-detected device: {self.device}")
logging.info(f"Auto-detected device: {self.device}")
@classmethod
def __get_path_fields__(cls) -> list[str]:
@@ -144,7 +132,7 @@ class RTCEvaluator:
self.device = cfg.device
# Load policy
logger.info(f"Loading policy from {cfg.policy.pretrained_path}")
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type)
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
self.policy = self.policy.to(self.device)
@@ -156,14 +144,14 @@ class RTCEvaluator:
self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor()
logger.info(f"Policy loaded: {self.policy.name}")
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
logging.info(f"Policy loaded: {self.policy.name}")
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
# Load dataset
logger.info(f"Loading dataset: {cfg.dataset.repo_id}")
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
# Create preprocessor/postprocessor
self.preprocessor, self.postprocessor = make_pre_post_processors(
@@ -178,10 +166,10 @@ class RTCEvaluator:
"""Run evaluation on two random dataset samples."""
# Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True)
logger.info(f"Output directory: {self.cfg.output_dir}")
logging.info(f"Output directory: {self.cfg.output_dir}")
logger.info("Starting RTC evaluation")
logger.info(f"Inference delay: {self.cfg.inference_delay}")
logging.info("Starting RTC evaluation")
logging.info(f"Inference delay: {self.cfg.inference_delay}")
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader)
@@ -202,7 +190,7 @@ class RTCEvaluator:
noise_clone = noise.clone()
# Generate actions WITHOUT RTC
logger.info("Generating actions WITHOUT RTC")
logging.info("Generating actions WITHOUT RTC")
self.policy.config.rtc_config.enabled = False
with torch.no_grad():
_ = self.policy.predict_action_chunk(
@@ -214,7 +202,7 @@ class RTCEvaluator:
self.policy.rtc_processor.reset_tracker()
# Generate actions WITH RTC
logger.info("Generating actions WITH RTC")
logging.info("Generating actions WITH RTC")
self.policy.config.rtc_config.enabled = True
with torch.no_grad():
_ = self.policy.predict_action_chunk(
@@ -225,12 +213,10 @@ class RTCEvaluator:
execution_horizon=self.cfg.rtc.execution_horizon,
)
# ================================================================
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
logger.info("Evaluation completed successfully")
logging.info("Evaluation completed successfully")
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
# Create side-by-side figures for denoising visualization
@@ -262,15 +248,20 @@ class RTCEvaluator:
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
-RTCDebugVisualizer.plot_waypoints(
-axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
-)
# Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
+# Plot ground truth on x_t axes
+RTCDebugVisualizer.plot_waypoints(
+axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
+)
+RTCDebugVisualizer.plot_waypoints(
+axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
+)
# Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
@@ -290,7 +281,7 @@ class RTCEvaluator:
def _save_figure(self, fig, path):
fig.tight_layout()
fig.savefig(path, dpi=150)
logger.info(f"Saved figure to {path}")
logging.info(f"Saved figure to {path}")
plt.close(fig)
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
@@ -372,10 +363,12 @@ def main(cfg: RTCEvalConfig):
# Set random seed for reproducibility
set_seed(cfg.seed)
logger.info("=" * 80)
logger.info("RTC Dataset Evaluation")
logger.info(f"Config: {cfg}")
logger.info("=" * 80)
init_logging()
logging.info("=" * 80)
logging.info("RTC Dataset Evaluation")
logging.info(f"Config: {cfg}")
logging.info("=" * 80)
evaluator = RTCEvaluator(cfg)
evaluator.run_evaluation()