diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 5bffd111c..77a283b70 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -21,7 +21,6 @@ Usage: import logging import os import random -import sys from dataclasses import dataclass, field import matplotlib.pyplot as plt @@ -37,17 +36,7 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer from lerobot.utils.hub import HubMixin - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - force=True, -) -logger = logging.getLogger(__name__) - -# Ensure logs are flushed immediately -sys.stdout.reconfigure(line_buffering=True) -sys.stderr.reconfigure(line_buffering=True) +from lerobot.utils.utils import init_logging def set_seed(seed: int): @@ -62,7 +51,6 @@ def set_seed(seed: int): torch.mps.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - logger.info(f"Random seed set to: {seed}") @dataclass @@ -128,7 +116,7 @@ class RTCEvalConfig(HubMixin): self.device = "mps" else: self.device = "cpu" - logger.info(f"Auto-detected device: {self.device}") + logging.info(f"Auto-detected device: {self.device}") @classmethod def __get_path_fields__(cls) -> list[str]: @@ -144,7 +132,7 @@ class RTCEvaluator: self.device = cfg.device # Load policy - logger.info(f"Loading policy from {cfg.policy.pretrained_path}") + logging.info(f"Loading policy from {cfg.policy.pretrained_path}") policy_class = get_policy_class(cfg.policy.type) self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path) self.policy = self.policy.to(self.device) @@ -156,14 +144,14 @@ class RTCEvaluator: self.policy.config.rtc_config = cfg.rtc self.policy.init_rtc_processor() - logger.info(f"Policy loaded: {self.policy.name}") - logger.info(f"RTC enabled: {cfg.rtc.enabled}") - logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}") + logging.info(f"Policy loaded: {self.policy.name}") + logging.info(f"RTC enabled: {cfg.rtc.enabled}") + logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}") # Load dataset - logger.info(f"Loading dataset: {cfg.dataset.repo_id}") + logging.info(f"Loading dataset: {cfg.dataset.repo_id}") self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30}) - logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes") + logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes") # Create preprocessor/postprocessor self.preprocessor, self.postprocessor = make_pre_post_processors( @@ -178,10 +166,10 @@ class RTCEvaluator: """Run evaluation on two random dataset samples.""" # Create output directory os.makedirs(self.cfg.output_dir, exist_ok=True) - logger.info(f"Output directory: {self.cfg.output_dir}") + logging.info(f"Output directory: {self.cfg.output_dir}") - logger.info("Starting RTC evaluation") - logger.info(f"Inference delay: {self.cfg.inference_delay}") + logging.info("Starting RTC evaluation") + logging.info(f"Inference delay: {self.cfg.inference_delay}") data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True) loader_iter = iter(data_loader) @@ -202,7 +190,7 @@ class RTCEvaluator: noise_clone = noise.clone() # Generate actions WITHOUT RTC - logger.info("Generating actions WITHOUT RTC") + logging.info("Generating actions WITHOUT RTC") self.policy.config.rtc_config.enabled = False with torch.no_grad(): _ = self.policy.predict_action_chunk( @@ -214,7 +202,7 @@ class RTCEvaluator: self.policy.rtc_processor.reset_tracker() # Generate actions WITH RTC - logger.info("Generating actions WITH RTC") + logging.info("Generating actions WITH RTC") self.policy.config.rtc_config.enabled = True with torch.no_grad(): _ = self.policy.predict_action_chunk( @@ -225,12 +213,10 @@ class RTCEvaluator: execution_horizon=self.cfg.rtc.execution_horizon, ) - # ================================================================ - rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps() self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over) - logger.info("Evaluation completed successfully") + logging.info("Evaluation completed successfully") def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over): # Create side-by-side figures for denoising visualization @@ -262,15 +248,20 @@ class RTCEvaluator: axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" ) - RTCDebugVisualizer.plot_waypoints( - axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" - ) - # Plot ground truth on x1_t axes RTCDebugVisualizer.plot_waypoints( axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" ) + # Plot ground truth on x_t axes + RTCDebugVisualizer.plot_waypoints( + axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + + RTCDebugVisualizer.plot_waypoints( + axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" + ) + # Save denoising plots self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) @@ -290,7 +281,7 @@ class RTCEvaluator: def _save_figure(self, fig, path): fig.tight_layout() fig.savefig(path, dpi=150) - logger.info(f"Saved figure to {path}") + logging.info(f"Saved figure to {path}") plt.close(fig) def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps): @@ -372,10 +363,12 @@ def main(cfg: RTCEvalConfig): # Set random seed for reproducibility set_seed(cfg.seed) - logger.info("=" * 80) - logger.info("RTC Dataset Evaluation") - logger.info(f"Config: {cfg}") - logger.info("=" * 80) + init_logging() + + logging.info("=" * 80) + logging.info("RTC Dataset Evaluation") + logging.info(f"Config: {cfg}") + logging.info("=" * 80) evaluator = RTCEvaluator(cfg) evaluator.run_evaluation()