fixup! Refactor plotting loging

This commit is contained in:
Eugene Mironov
2025-11-04 02:11:54 +07:00
parent 84df6cd13d
commit aaa308b158
+30 -37
View File
@@ -21,7 +21,6 @@ Usage:
import logging import logging
import os import os
import random import random
import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@@ -37,17 +36,7 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.configuration_rtc import RTCConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
from lerobot.utils.hub import HubMixin from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
force=True,
)
logger = logging.getLogger(__name__)
# Ensure logs are flushed immediately
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def set_seed(seed: int): def set_seed(seed: int):
@@ -62,7 +51,6 @@ def set_seed(seed: int):
torch.mps.manual_seed(seed) torch.mps.manual_seed(seed)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
logger.info(f"Random seed set to: {seed}")
@dataclass @dataclass
@@ -128,7 +116,7 @@ class RTCEvalConfig(HubMixin):
self.device = "mps" self.device = "mps"
else: else:
self.device = "cpu" self.device = "cpu"
logger.info(f"Auto-detected device: {self.device}") logging.info(f"Auto-detected device: {self.device}")
@classmethod @classmethod
def __get_path_fields__(cls) -> list[str]: def __get_path_fields__(cls) -> list[str]:
@@ -144,7 +132,7 @@ class RTCEvaluator:
self.device = cfg.device self.device = cfg.device
# Load policy # Load policy
logger.info(f"Loading policy from {cfg.policy.pretrained_path}") logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
policy_class = get_policy_class(cfg.policy.type) policy_class = get_policy_class(cfg.policy.type)
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path) self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
self.policy = self.policy.to(self.device) self.policy = self.policy.to(self.device)
@@ -156,14 +144,14 @@ class RTCEvaluator:
self.policy.config.rtc_config = cfg.rtc self.policy.config.rtc_config = cfg.rtc
self.policy.init_rtc_processor() self.policy.init_rtc_processor()
logger.info(f"Policy loaded: {self.policy.name}") logging.info(f"Policy loaded: {self.policy.name}")
logger.info(f"RTC enabled: {cfg.rtc.enabled}") logging.info(f"RTC enabled: {cfg.rtc.enabled}")
logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}") logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
# Load dataset # Load dataset
logger.info(f"Loading dataset: {cfg.dataset.repo_id}") logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30}) self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes") logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
# Create preprocessor/postprocessor # Create preprocessor/postprocessor
self.preprocessor, self.postprocessor = make_pre_post_processors( self.preprocessor, self.postprocessor = make_pre_post_processors(
@@ -178,10 +166,10 @@ class RTCEvaluator:
"""Run evaluation on two random dataset samples.""" """Run evaluation on two random dataset samples."""
# Create output directory # Create output directory
os.makedirs(self.cfg.output_dir, exist_ok=True) os.makedirs(self.cfg.output_dir, exist_ok=True)
logger.info(f"Output directory: {self.cfg.output_dir}") logging.info(f"Output directory: {self.cfg.output_dir}")
logger.info("Starting RTC evaluation") logging.info("Starting RTC evaluation")
logger.info(f"Inference delay: {self.cfg.inference_delay}") logging.info(f"Inference delay: {self.cfg.inference_delay}")
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True) data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
loader_iter = iter(data_loader) loader_iter = iter(data_loader)
@@ -202,7 +190,7 @@ class RTCEvaluator:
noise_clone = noise.clone() noise_clone = noise.clone()
# Generate actions WITHOUT RTC # Generate actions WITHOUT RTC
logger.info("Generating actions WITHOUT RTC") logging.info("Generating actions WITHOUT RTC")
self.policy.config.rtc_config.enabled = False self.policy.config.rtc_config.enabled = False
with torch.no_grad(): with torch.no_grad():
_ = self.policy.predict_action_chunk( _ = self.policy.predict_action_chunk(
@@ -214,7 +202,7 @@ class RTCEvaluator:
self.policy.rtc_processor.reset_tracker() self.policy.rtc_processor.reset_tracker()
# Generate actions WITH RTC # Generate actions WITH RTC
logger.info("Generating actions WITH RTC") logging.info("Generating actions WITH RTC")
self.policy.config.rtc_config.enabled = True self.policy.config.rtc_config.enabled = True
with torch.no_grad(): with torch.no_grad():
_ = self.policy.predict_action_chunk( _ = self.policy.predict_action_chunk(
@@ -225,12 +213,10 @@ class RTCEvaluator:
execution_horizon=self.cfg.rtc.execution_horizon, execution_horizon=self.cfg.rtc.execution_horizon,
) )
# ================================================================
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps() rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over) self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
logger.info("Evaluation completed successfully") logging.info("Evaluation completed successfully")
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over): def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
# Create side-by-side figures for denoising visualization # Create side-by-side figures for denoising visualization
@@ -262,15 +248,20 @@ class RTCEvaluator:
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
) )
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Plot ground truth on x1_t axes # Plot ground truth on x1_t axes
RTCDebugVisualizer.plot_waypoints( RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth" axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
) )
# Plot ground truth on x_t axes
RTCDebugVisualizer.plot_waypoints(
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
RTCDebugVisualizer.plot_waypoints(
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
)
# Save denoising plots # Save denoising plots
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png")) self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png")) self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
@@ -290,7 +281,7 @@ class RTCEvaluator:
def _save_figure(self, fig, path): def _save_figure(self, fig, path):
fig.tight_layout() fig.tight_layout()
fig.savefig(path, dpi=150) fig.savefig(path, dpi=150)
logger.info(f"Saved figure to {path}") logging.info(f"Saved figure to {path}")
plt.close(fig) plt.close(fig)
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps): def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
@@ -372,10 +363,12 @@ def main(cfg: RTCEvalConfig):
# Set random seed for reproducibility # Set random seed for reproducibility
set_seed(cfg.seed) set_seed(cfg.seed)
logger.info("=" * 80) init_logging()
logger.info("RTC Dataset Evaluation")
logger.info(f"Config: {cfg}") logging.info("=" * 80)
logger.info("=" * 80) logging.info("RTC Dataset Evaluation")
logging.info(f"Config: {cfg}")
logging.info("=" * 80)
evaluator = RTCEvaluator(cfg) evaluator = RTCEvaluator(cfg)
evaluator.run_evaluation() evaluator.run_evaluation()