mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
fixup! Fixup eval with real robot
This commit is contained in:
@@ -83,10 +83,17 @@ import os
|
|||||||
import random
|
import random
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
MATPLOTLIB_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
MATPLOTLIB_AVAILABLE = False
|
||||||
|
plt = None
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -114,6 +121,16 @@ def set_seed(seed: int):
|
|||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def _check_matplotlib_available():
|
||||||
|
"""Check if matplotlib is available, raise helpful error if not."""
|
||||||
|
if not MATPLOTLIB_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"matplotlib is required for RTC debug visualizations. "
|
||||||
|
"Please install it by running:\n"
|
||||||
|
" uv pip install -e '.[matplotlib-dep]'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RTCEvalConfig(HubMixin):
|
class RTCEvalConfig(HubMixin):
|
||||||
"""Configuration for RTC evaluation."""
|
"""Configuration for RTC evaluation."""
|
||||||
@@ -686,6 +703,8 @@ class RTCEvaluator:
|
|||||||
no_rtc_actions: Final actions from non-RTC policy
|
no_rtc_actions: Final actions from non-RTC policy
|
||||||
prev_chunk_left_over: Previous chunk used as ground truth
|
prev_chunk_left_over: Previous chunk used as ground truth
|
||||||
"""
|
"""
|
||||||
|
_check_matplotlib_available()
|
||||||
|
|
||||||
# Remove batch dimension if present
|
# Remove batch dimension if present
|
||||||
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||||
no_rtc_actions_plot = (
|
no_rtc_actions_plot = (
|
||||||
@@ -776,6 +795,8 @@ class RTCEvaluator:
|
|||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
||||||
|
_check_matplotlib_available()
|
||||||
|
|
||||||
# Create side-by-side figures for denoising visualization
|
# Create side-by-side figures for denoising visualization
|
||||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||||
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
||||||
|
|||||||
@@ -573,7 +573,7 @@ def demo_cli(cfg: RTCDemoConfig):
|
|||||||
# The processor won't be created
|
# The processor won't be created
|
||||||
policy.init_rtc_processor()
|
policy.init_rtc_processor()
|
||||||
|
|
||||||
assert policy.name in ["smolvla"], "Only smolvla are supported for RTC"
|
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||||
|
|
||||||
policy = policy.to(cfg.device)
|
policy = policy.to(cfg.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|||||||
+3
-2
@@ -98,6 +98,7 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
|||||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||||
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
transformers-dep = ["transformers>=4.53.0,<5.0.0"]
|
||||||
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb)
|
||||||
|
matplotlib-dep = ["matplotlib>=3.10.3,<4.0.0"]
|
||||||
|
|
||||||
# Motors
|
# Motors
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||||
@@ -132,10 +133,10 @@ groot = [
|
|||||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||||
|
|
||||||
# Features
|
# Features
|
||||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"]
|
async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"]
|
||||||
|
|
||||||
# Development
|
# Development
|
||||||
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "matplotlib>=3.10.3,<4.0.0"]
|
dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||||
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"]
|
||||||
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user