mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
fixup! Refactor plotting loging
This commit is contained in:
@@ -21,7 +21,6 @@ Usage:
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@@ -37,17 +36,7 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
force=True,
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Ensure logs are flushed immediately
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
sys.stderr.reconfigure(line_buffering=True)
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
@@ -62,7 +51,6 @@ def set_seed(seed: int):
|
||||
torch.mps.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
logger.info(f"Random seed set to: {seed}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -128,7 +116,7 @@ class RTCEvalConfig(HubMixin):
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
logger.info(f"Auto-detected device: {self.device}")
|
||||
logging.info(f"Auto-detected device: {self.device}")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
@@ -144,7 +132,7 @@ class RTCEvaluator:
|
||||
self.device = cfg.device
|
||||
|
||||
# Load policy
|
||||
logger.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
||||
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||
self.policy = self.policy.to(self.device)
|
||||
@@ -156,14 +144,14 @@ class RTCEvaluator:
|
||||
self.policy.config.rtc_config = cfg.rtc
|
||||
self.policy.init_rtc_processor()
|
||||
|
||||
logger.info(f"Policy loaded: {self.policy.name}")
|
||||
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||
logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||
logging.info(f"Policy loaded: {self.policy.name}")
|
||||
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||
|
||||
# Load dataset
|
||||
logger.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
||||
logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||
|
||||
# Create preprocessor/postprocessor
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
@@ -178,10 +166,10 @@ class RTCEvaluator:
|
||||
"""Run evaluation on two random dataset samples."""
|
||||
# Create output directory
|
||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||
logger.info(f"Output directory: {self.cfg.output_dir}")
|
||||
logging.info(f"Output directory: {self.cfg.output_dir}")
|
||||
|
||||
logger.info("Starting RTC evaluation")
|
||||
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||
logging.info("Starting RTC evaluation")
|
||||
logging.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||
|
||||
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||
loader_iter = iter(data_loader)
|
||||
@@ -202,7 +190,7 @@ class RTCEvaluator:
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Generate actions WITHOUT RTC
|
||||
logger.info("Generating actions WITHOUT RTC")
|
||||
logging.info("Generating actions WITHOUT RTC")
|
||||
self.policy.config.rtc_config.enabled = False
|
||||
with torch.no_grad():
|
||||
_ = self.policy.predict_action_chunk(
|
||||
@@ -214,7 +202,7 @@ class RTCEvaluator:
|
||||
self.policy.rtc_processor.reset_tracker()
|
||||
|
||||
# Generate actions WITH RTC
|
||||
logger.info("Generating actions WITH RTC")
|
||||
logging.info("Generating actions WITH RTC")
|
||||
self.policy.config.rtc_config.enabled = True
|
||||
with torch.no_grad():
|
||||
_ = self.policy.predict_action_chunk(
|
||||
@@ -225,12 +213,10 @@ class RTCEvaluator:
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
)
|
||||
|
||||
# ================================================================
|
||||
|
||||
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
|
||||
|
||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
|
||||
logger.info("Evaluation completed successfully")
|
||||
logging.info("Evaluation completed successfully")
|
||||
|
||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
|
||||
# Create side-by-side figures for denoising visualization
|
||||
@@ -262,15 +248,20 @@ class RTCEvaluator:
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Save denoising plots
|
||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||
@@ -290,7 +281,7 @@ class RTCEvaluator:
|
||||
def _save_figure(self, fig, path):
|
||||
fig.tight_layout()
|
||||
fig.savefig(path, dpi=150)
|
||||
logger.info(f"Saved figure to {path}")
|
||||
logging.info(f"Saved figure to {path}")
|
||||
plt.close(fig)
|
||||
|
||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||
@@ -372,10 +363,12 @@ def main(cfg: RTCEvalConfig):
|
||||
# Set random seed for reproducibility
|
||||
set_seed(cfg.seed)
|
||||
|
||||
logger.info("=" * 80)
|
||||
logger.info("RTC Dataset Evaluation")
|
||||
logger.info(f"Config: {cfg}")
|
||||
logger.info("=" * 80)
|
||||
init_logging()
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("RTC Dataset Evaluation")
|
||||
logging.info(f"Config: {cfg}")
|
||||
logging.info("=" * 80)
|
||||
|
||||
evaluator = RTCEvaluator(cfg)
|
||||
evaluator.run_evaluation()
|
||||
|
||||
Reference in New Issue
Block a user