mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
Update images
This commit is contained in:
@@ -39,8 +39,9 @@ Usage:
|
|||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
--policy.path=lerobot/pi05_libero_finetuned \
|
--policy.path=lerobot/pi05_libero_finetuned \
|
||||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
--rtc.execution_horizon=8 \
|
--rtc.execution_horizon=10 \
|
||||||
--device=mps
|
--device=mps
|
||||||
|
--seed=10
|
||||||
|
|
||||||
# Basic usage with pi0.5 policy with cuda device
|
# Basic usage with pi0.5 policy with cuda device
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
@@ -795,16 +796,34 @@ class RTCEvaluator:
|
|||||||
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||||
ax.set_xlim(-0.5, max_len - 0.5)
|
ax.set_xlim(-0.5, max_len - 0.5)
|
||||||
|
|
||||||
# Add legend only to first subplot
|
|
||||||
if dim_idx == 0:
|
|
||||||
ax.legend(loc="best", fontsize=9)
|
|
||||||
|
|
||||||
axes[-1].set_xlabel("Step", fontsize=10)
|
axes[-1].set_xlabel("Step", fontsize=10)
|
||||||
|
|
||||||
|
# Collect legend handles and labels from first subplot
|
||||||
|
handles, labels = axes[0].get_legend_handles_labels()
|
||||||
|
# Remove duplicates while preserving order
|
||||||
|
seen = set()
|
||||||
|
unique_handles = []
|
||||||
|
unique_labels = []
|
||||||
|
for handle, label in zip(handles, labels, strict=True):
|
||||||
|
if label not in seen:
|
||||||
|
seen.add(label)
|
||||||
|
unique_handles.append(handle)
|
||||||
|
unique_labels.append(label)
|
||||||
|
|
||||||
|
# Add legend outside the plot area (to the right)
|
||||||
|
fig.legend(
|
||||||
|
unique_handles,
|
||||||
|
unique_labels,
|
||||||
|
loc="center right",
|
||||||
|
fontsize=9,
|
||||||
|
bbox_to_anchor=(1.0, 0.5),
|
||||||
|
framealpha=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
# Save figure
|
# Save figure
|
||||||
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||||
fig.tight_layout()
|
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
|
||||||
fig.savefig(output_path, dpi=150)
|
fig.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||||
logging.info(f"Saved final actions comparison to {output_path}")
|
logging.info(f"Saved final actions comparison to {output_path}")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
@@ -825,6 +844,7 @@ class RTCEvaluator:
|
|||||||
axs_corr[:, 1], # Right column for correction
|
axs_corr[:, 1], # Right column for correction
|
||||||
axs_x1t[:, 1], # Right column for x1_t
|
axs_x1t[:, 1], # Right column for x1_t
|
||||||
num_steps,
|
num_steps,
|
||||||
|
add_labels=True, # Add labels for RTC (right column)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._plot_denoising_steps_from_tracker(
|
self._plot_denoising_steps_from_tracker(
|
||||||
@@ -834,6 +854,7 @@ class RTCEvaluator:
|
|||||||
axs_corr[:, 0], # Left column for correction
|
axs_corr[:, 0], # Left column for correction
|
||||||
axs_x1t[:, 0], # Left column for x1_t
|
axs_x1t[:, 0], # Left column for x1_t
|
||||||
num_steps,
|
num_steps,
|
||||||
|
add_labels=False, # No labels for No RTC (left column)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
|
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
|
||||||
@@ -849,15 +870,21 @@ class RTCEvaluator:
|
|||||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot ground truth on x_t axes
|
# Plot ground truth on x_t axes (no labels for left column)
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
|
||||||
)
|
)
|
||||||
|
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add legends outside the plot area for each figure
|
||||||
|
self._add_figure_legend(fig_xt, axs_xt)
|
||||||
|
self._add_figure_legend(fig_vt, axs_vt)
|
||||||
|
self._add_figure_legend(fig_corr, axs_corr)
|
||||||
|
self._add_figure_legend(fig_x1t, axs_x1t)
|
||||||
|
|
||||||
# Save denoising plots
|
# Save denoising plots
|
||||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||||
@@ -875,13 +902,47 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
return fig, axs
|
return fig, axs
|
||||||
|
|
||||||
|
def _add_figure_legend(self, fig, axs):
|
||||||
|
"""Add a legend outside the plot area on the right side.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fig: Matplotlib figure to add legend to
|
||||||
|
axs: Array of axes to collect legend handles from
|
||||||
|
"""
|
||||||
|
# Collect all handles and labels from the first row of axes (right column)
|
||||||
|
handles, labels = axs[0, 1].get_legend_handles_labels()
|
||||||
|
|
||||||
|
# Remove duplicates while preserving order
|
||||||
|
seen = set()
|
||||||
|
unique_handles = []
|
||||||
|
unique_labels = []
|
||||||
|
for handle, label in zip(handles, labels, strict=True):
|
||||||
|
if label not in seen:
|
||||||
|
seen.add(label)
|
||||||
|
unique_handles.append(handle)
|
||||||
|
unique_labels.append(label)
|
||||||
|
|
||||||
|
# Add legend outside the plot area (to the right, close to charts)
|
||||||
|
if unique_handles:
|
||||||
|
fig.legend(
|
||||||
|
unique_handles,
|
||||||
|
unique_labels,
|
||||||
|
loc="center left",
|
||||||
|
fontsize=8,
|
||||||
|
bbox_to_anchor=(0.87, 0.5),
|
||||||
|
framealpha=0.9,
|
||||||
|
ncol=1,
|
||||||
|
)
|
||||||
|
|
||||||
def _save_figure(self, fig, path):
|
def _save_figure(self, fig, path):
|
||||||
fig.tight_layout()
|
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
|
||||||
fig.savefig(path, dpi=150)
|
fig.savefig(path, dpi=150, bbox_inches="tight")
|
||||||
logging.info(f"Saved figure to {path}")
|
logging.info(f"Saved figure to {path}")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
def _plot_denoising_steps_from_tracker(self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps):
|
def _plot_denoising_steps_from_tracker(
|
||||||
|
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
|
||||||
|
):
|
||||||
"""Plot denoising steps from tracker data.
|
"""Plot denoising steps from tracker data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -891,6 +952,7 @@ class RTCEvaluator:
|
|||||||
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
|
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
|
||||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||||
num_steps: Total number of denoising steps for colormap
|
num_steps: Total number of denoising steps for colormap
|
||||||
|
add_labels: Whether to add legend labels for the plots
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logging.info("=" * 80)
|
logging.info("=" * 80)
|
||||||
@@ -905,17 +967,18 @@ class RTCEvaluator:
|
|||||||
|
|
||||||
for step_idx, debug_step in enumerate(debug_steps):
|
for step_idx, debug_step in enumerate(debug_steps):
|
||||||
color = colors[step_idx % len(colors)]
|
color = colors[step_idx % len(colors)]
|
||||||
|
label = f"Step {step_idx}" if add_labels else None
|
||||||
|
|
||||||
# Plot x_t
|
# Plot x_t
|
||||||
if debug_step.x_t is not None:
|
if debug_step.x_t is not None:
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
xt_axs, debug_step.x_t, start_from=0, color=color, label=f"Step {step_idx}"
|
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot v_t
|
# Plot v_t
|
||||||
if debug_step.v_t is not None:
|
if debug_step.v_t is not None:
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
vt_axs, debug_step.v_t, start_from=0, color=color, label=f"Step {step_idx}"
|
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot correction on separate axes
|
# Plot correction on separate axes
|
||||||
@@ -925,17 +988,18 @@ class RTCEvaluator:
|
|||||||
debug_step.correction,
|
debug_step.correction,
|
||||||
start_from=0,
|
start_from=0,
|
||||||
color=color,
|
color=color,
|
||||||
label=f"Step {step_idx}",
|
label=label,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot x1_t (predicted state)
|
# Plot x1_t (predicted state)
|
||||||
if x1t_axs is not None and debug_step.x1_t is not None:
|
if x1t_axs is not None and debug_step.x1_t is not None:
|
||||||
|
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
|
||||||
RTCDebugVisualizer.plot_waypoints(
|
RTCDebugVisualizer.plot_waypoints(
|
||||||
x1t_axs,
|
x1t_axs,
|
||||||
debug_step.x1_t,
|
debug_step.x1_t,
|
||||||
start_from=0,
|
start_from=0,
|
||||||
color=color,
|
color=color,
|
||||||
label=f"x1_t Step {step_idx}",
|
label=x1t_label,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Plot error in orange dashed
|
# Plot error in orange dashed
|
||||||
@@ -947,6 +1011,7 @@ class RTCEvaluator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_dims = min(error_chunk.shape[-1], 6)
|
num_dims = min(error_chunk.shape[-1], 6)
|
||||||
|
error_label = f"error Step {step_idx}" if add_labels else None
|
||||||
for j in range(num_dims):
|
for j in range(num_dims):
|
||||||
x1t_axs[j].plot(
|
x1t_axs[j].plot(
|
||||||
np.arange(0, error_chunk.shape[0]),
|
np.arange(0, error_chunk.shape[0]),
|
||||||
@@ -954,7 +1019,7 @@ class RTCEvaluator:
|
|||||||
color="orange",
|
color="orange",
|
||||||
linestyle="--",
|
linestyle="--",
|
||||||
alpha=0.7,
|
alpha=0.7,
|
||||||
label=f"error Step {step_idx}",
|
label=error_label,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Recalculate axis limits after plotting to ensure proper scaling
|
# Recalculate axis limits after plotting to ensure proper scaling
|
||||||
|
|||||||
@@ -111,7 +111,3 @@ class RTCDebugVisualizer:
|
|||||||
if not ax.yaxis.get_label().get_text():
|
if not ax.yaxis.get_label().get_text():
|
||||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
# Add legend if label provided and this is the first dimension
|
|
||||||
if label and dim_idx == 0:
|
|
||||||
ax.legend(loc="best", fontsize=8)
|
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.3 MiB After Width: | Height: | Size: 1.3 MiB |
Reference in New Issue
Block a user