diff --git a/examples/openarms/evaluate_ee.py b/examples/openarms/evaluate_ee.py index 9f87f403a..ce6438e3e 100644 --- a/examples/openarms/evaluate_ee.py +++ b/examples/openarms/evaluate_ee.py @@ -35,12 +35,13 @@ import numpy as np import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.train import TrainPipelineConfig from lerobot.utils.relative_actions import ( convert_state_to_relative, convert_from_relative_actions, PerTimestepNormalizer, ) -from lerobot.configs.policies import PreTrainedConfig from lerobot.model.kinematics import RobotKinematics from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline @@ -71,11 +72,6 @@ FPS = 30 EPISODE_TIME_SEC = 1000 RESET_TIME_SEC = 60 -# UMI-style relative action/state (must match training config) -USE_RELATIVE_ACTIONS = False # If True, policy outputs relative EE actions -USE_RELATIVE_STATE = False # If True, convert state to relative before policy input -RELATIVE_STATS_PATH = None # Path to relative_stats.pt (for per-timestep normalization) - # Robot CAN interfaces FOLLOWER_LEFT_PORT = "can0" FOLLOWER_RIGHT_PORT = "can1" @@ -102,6 +98,51 @@ LEFT_URDF_JOINTS = [f"openarm_left_joint{i}" for i in range(1, 8)] RIGHT_URDF_JOINTS = [f"openarm_right_joint{i}" for i in range(1, 8)] +def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool, bool]: + """Auto-detect relative action/state settings and load normalizer from checkpoint.""" + model_path = Path(model_path) if isinstance(model_path, str) else model_path + normalizer = None + use_relative_actions = False + use_relative_state = False + + # Try local path first + if model_path.exists(): + stats_path = model_path / "relative_stats.pt" + if stats_path.exists(): + normalizer = PerTimestepNormalizer.load(stats_path) + use_relative_actions = True + print(f" Loaded per-timestep stats from: {stats_path}") + + config_path = model_path / "train_config.json" + if config_path.exists(): + cfg = TrainPipelineConfig.from_pretrained(model_path) + use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions) + use_relative_state = getattr(cfg, "use_relative_state", False) + else: + # Try hub + try: + from huggingface_hub import hf_hub_download + try: + stats_file = hf_hub_download(repo_id=str(model_path), filename="relative_stats.pt") + normalizer = PerTimestepNormalizer.load(stats_file) + use_relative_actions = True + print(" Loaded per-timestep stats from hub") + except Exception: + pass # No stats file means no relative actions + + try: + config_file = hf_hub_download(repo_id=str(model_path), filename="train_config.json") + cfg = TrainPipelineConfig.from_pretrained(Path(config_file).parent) + use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions) + use_relative_state = getattr(cfg, "use_relative_state", False) + except Exception: + pass + except Exception as e: + print(f" Warning: Could not load relative config: {e}") + + return normalizer, use_relative_actions, use_relative_state + + def build_kinematics_pipelines(urdf_path: str, left_ee_frame: str, right_ee_frame: str): """Build FK and IK pipelines for bimanual robot.""" left_kinematics = RobotKinematics( @@ -337,11 +378,6 @@ def main(): print(f"Task: {TASK_DESCRIPTION}") print(f"Episodes: {NUM_EPISODES}") print(f"Episode Duration: {EPISODE_TIME_SEC}s") - print(f"\nUMI-style relative mode:") - print(f" Relative actions: {USE_RELATIVE_ACTIONS}") - print(f" Relative state: {USE_RELATIVE_STATE}") - if RELATIVE_STATS_PATH: - print(f" Stats path: {RELATIVE_STATS_PATH}") print("=" * 70) # Resolve URDF path @@ -411,15 +447,13 @@ def main(): ) print(" Policy loaded") - # Load relative action normalizer if using relative actions - relative_normalizer = None - if USE_RELATIVE_ACTIONS and RELATIVE_STATS_PATH: - stats_path = Path(RELATIVE_STATS_PATH) - if stats_path.exists(): - print(f" Loading relative stats from: {stats_path}") - relative_normalizer = PerTimestepNormalizer.load(stats_path) - else: - print(f" WARNING: Relative stats not found at {stats_path}") + # Auto-detect relative action/state settings from checkpoint + relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID) + + mode = "absolute" + if use_relative_actions: + mode = "relative actions + state" if use_relative_state else "relative actions only" + print(f" Mode: {mode}") # Initialize keyboard listener print("\n[4/4] Starting evaluation...") @@ -449,8 +483,8 @@ def main(): fps=FPS, duration_s=EPISODE_TIME_SEC, events=events, - use_relative_actions=USE_RELATIVE_ACTIONS, - use_relative_state=USE_RELATIVE_STATE, + use_relative_actions=use_relative_actions, + use_relative_state=use_relative_state, relative_normalizer=relative_normalizer, )