mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fixup! Refactor plotting loging
This commit is contained in:
@@ -21,7 +21,6 @@ Usage:
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@@ -37,17 +36,7 @@ from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
|||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
|
from lerobot.utils.utils import init_logging
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|
||||||
force=True,
|
|
||||||
)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Ensure logs are flushed immediately
|
|
||||||
sys.stdout.reconfigure(line_buffering=True)
|
|
||||||
sys.stderr.reconfigure(line_buffering=True)
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
@@ -62,7 +51,6 @@ def set_seed(seed: int):
|
|||||||
torch.mps.manual_seed(seed)
|
torch.mps.manual_seed(seed)
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
logger.info(f"Random seed set to: {seed}")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -128,7 +116,7 @@ class RTCEvalConfig(HubMixin):
|
|||||||
self.device = "mps"
|
self.device = "mps"
|
||||||
else:
|
else:
|
||||||
self.device = "cpu"
|
self.device = "cpu"
|
||||||
logger.info(f"Auto-detected device: {self.device}")
|
logging.info(f"Auto-detected device: {self.device}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
@@ -144,7 +132,7 @@ class RTCEvaluator:
|
|||||||
self.device = cfg.device
|
self.device = cfg.device
|
||||||
|
|
||||||
# Load policy
|
# Load policy
|
||||||
logger.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
logging.info(f"Loading policy from {cfg.policy.pretrained_path}")
|
||||||
policy_class = get_policy_class(cfg.policy.type)
|
policy_class = get_policy_class(cfg.policy.type)
|
||||||
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
self.policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
||||||
self.policy = self.policy.to(self.device)
|
self.policy = self.policy.to(self.device)
|
||||||
@@ -156,14 +144,14 @@ class RTCEvaluator:
|
|||||||
self.policy.config.rtc_config = cfg.rtc
|
self.policy.config.rtc_config = cfg.rtc
|
||||||
self.policy.init_rtc_processor()
|
self.policy.init_rtc_processor()
|
||||||
|
|
||||||
logger.info(f"Policy loaded: {self.policy.name}")
|
logging.info(f"Policy loaded: {self.policy.name}")
|
||||||
logger.info(f"RTC enabled: {cfg.rtc.enabled}")
|
logging.info(f"RTC enabled: {cfg.rtc.enabled}")
|
||||||
logger.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
logging.info(f"Execution horizon: {cfg.rtc.execution_horizon}")
|
||||||
|
|
||||||
# Load dataset
|
# Load dataset
|
||||||
logger.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||||
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
||||||
logger.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||||
|
|
||||||
# Create preprocessor/postprocessor
|
# Create preprocessor/postprocessor
|
||||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||||
@@ -178,10 +166,10 @@ class RTCEvaluator:
|
|||||||
"""Run evaluation on two random dataset samples."""
|
"""Run evaluation on two random dataset samples."""
|
||||||
# Create output directory
|
# Create output directory
|
||||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||||
logger.info(f"Output directory: {self.cfg.output_dir}")
|
logging.info(f"Output directory: {self.cfg.output_dir}")
|
||||||
|
|
||||||
logger.info("Starting RTC evaluation")
|
logging.info("Starting RTC evaluation")
|
||||||
logger.info(f"Inference delay: {self.cfg.inference_delay}")
|
logging.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||||
|
|
||||||
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||||
loader_iter = iter(data_loader)
|
loader_iter = iter(data_loader)
|
||||||
@@ -202,7 +190,7 @@ class RTCEvaluator:
|
|||||||
noise_clone = noise.clone()
|
noise_clone = noise.clone()
|
||||||
|
|
||||||
# Generate actions WITHOUT RTC
|
# Generate actions WITHOUT RTC
|
||||||
logger.info("Generating actions WITHOUT RTC")
|
logging.info("Generating actions WITHOUT RTC")
|
||||||
self.policy.config.rtc_config.enabled = False
|
self.policy.config.rtc_config.enabled = False
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.policy.predict_action_chunk(
|
_ = self.policy.predict_action_chunk(
|
||||||
@@ -214,7 +202,7 @@ class RTCEvaluator:
|
|||||||
self.policy.rtc_processor.reset_tracker()
|
self.policy.rtc_processor.reset_tracker()
|
||||||
|
|
||||||
# Generate actions WITH RTC
|
# Generate actions WITH RTC
|
||||||
logger.info("Generating actions WITH RTC")
|
logging.info("Generating actions WITH RTC")
|
||||||
self.policy.config.rtc_config.enabled = True
|
self.policy.config.rtc_config.enabled = True
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
_ = self.policy.predict_action_chunk(
|
_ = self.policy.predict_action_chunk(
|
||||||
@@ -225,12 +213,10 @@ class RTCEvaluator:
|
|||||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ================================================================
|
|
||||||
|
|
||||||
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
|
rtc_tracked_steps = self.policy.rtc_processor.get_all_debug_steps()
|
||||||
|
|
||||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
|
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over)
|
||||||
logger.info("Evaluation completed successfully")
|
logging.info("Evaluation completed successfully")
|
||||||
|
|
||||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
|
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over):
|
||||||
# Create side-by-side figures for denoising visualization
|
# Create side-by-side figures for denoising visualization
|
||||||
@@ -262,15 +248,20 @@ class RTCEvaluator:
|
|||||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
)
|
)
|
||||||
|
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
|
||||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plot ground truth on x1_t axes
|
# Plot ground truth on x1_t axes
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Plot ground truth on x_t axes
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
|
)
|
||||||
|
|
||||||
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
|
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
|
)
|
||||||
|
|
||||||
# Save denoising plots
|
# Save denoising plots
|
||||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||||
@@ -290,7 +281,7 @@ class RTCEvaluator:
|
|||||||
def _save_figure(self, fig, path):
|
def _save_figure(self, fig, path):
|
||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
fig.savefig(path, dpi=150)
|
fig.savefig(path, dpi=150)
|
||||||
logger.info(f"Saved figure to {path}")
|
logging.info(f"Saved figure to {path}")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
|
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, x1t_axs, num_steps):
|
||||||
@@ -372,10 +363,12 @@ def main(cfg: RTCEvalConfig):
|
|||||||
# Set random seed for reproducibility
|
# Set random seed for reproducibility
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
|
|
||||||
logger.info("=" * 80)
|
init_logging()
|
||||||
logger.info("RTC Dataset Evaluation")
|
|
||||||
logger.info(f"Config: {cfg}")
|
logging.info("=" * 80)
|
||||||
logger.info("=" * 80)
|
logging.info("RTC Dataset Evaluation")
|
||||||
|
logging.info(f"Config: {cfg}")
|
||||||
|
logging.info("=" * 80)
|
||||||
|
|
||||||
evaluator = RTCEvaluator(cfg)
|
evaluator = RTCEvaluator(cfg)
|
||||||
evaluator.run_evaluation()
|
evaluator.run_evaluation()
|
||||||
|
|||||||
Reference in New Issue
Block a user