auto detect mode and stats

This commit is contained in:
Pepijn
2026-01-08 16:34:46 +01:00
parent c720a4a347
commit cf75b75474
+56 -22
View File
@@ -35,12 +35,13 @@ import numpy as np
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.utils.relative_actions import (
convert_state_to_relative,
convert_from_relative_actions,
PerTimestepNormalizer,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
@@ -71,11 +72,6 @@ FPS = 30
EPISODE_TIME_SEC = 1000
RESET_TIME_SEC = 60
# UMI-style relative action/state (must match training config)
USE_RELATIVE_ACTIONS = False # If True, policy outputs relative EE actions
USE_RELATIVE_STATE = False # If True, convert state to relative before policy input
RELATIVE_STATS_PATH = None # Path to relative_stats.pt (for per-timestep normalization)
# Robot CAN interfaces
FOLLOWER_LEFT_PORT = "can0"
FOLLOWER_RIGHT_PORT = "can1"
@@ -102,6 +98,51 @@ LEFT_URDF_JOINTS = [f"openarm_left_joint{i}" for i in range(1, 8)]
RIGHT_URDF_JOINTS = [f"openarm_right_joint{i}" for i in range(1, 8)]
def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool, bool]:
"""Auto-detect relative action/state settings and load normalizer from checkpoint."""
model_path = Path(model_path) if isinstance(model_path, str) else model_path
normalizer = None
use_relative_actions = False
use_relative_state = False
# Try local path first
if model_path.exists():
stats_path = model_path / "relative_stats.pt"
if stats_path.exists():
normalizer = PerTimestepNormalizer.load(stats_path)
use_relative_actions = True
print(f" Loaded per-timestep stats from: {stats_path}")
config_path = model_path / "train_config.json"
if config_path.exists():
cfg = TrainPipelineConfig.from_pretrained(model_path)
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
use_relative_state = getattr(cfg, "use_relative_state", False)
else:
# Try hub
try:
from huggingface_hub import hf_hub_download
try:
stats_file = hf_hub_download(repo_id=str(model_path), filename="relative_stats.pt")
normalizer = PerTimestepNormalizer.load(stats_file)
use_relative_actions = True
print(" Loaded per-timestep stats from hub")
except Exception:
pass # No stats file means no relative actions
try:
config_file = hf_hub_download(repo_id=str(model_path), filename="train_config.json")
cfg = TrainPipelineConfig.from_pretrained(Path(config_file).parent)
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
use_relative_state = getattr(cfg, "use_relative_state", False)
except Exception:
pass
except Exception as e:
print(f" Warning: Could not load relative config: {e}")
return normalizer, use_relative_actions, use_relative_state
def build_kinematics_pipelines(urdf_path: str, left_ee_frame: str, right_ee_frame: str):
"""Build FK and IK pipelines for bimanual robot."""
left_kinematics = RobotKinematics(
@@ -337,11 +378,6 @@ def main():
print(f"Task: {TASK_DESCRIPTION}")
print(f"Episodes: {NUM_EPISODES}")
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
print(f"\nUMI-style relative mode:")
print(f" Relative actions: {USE_RELATIVE_ACTIONS}")
print(f" Relative state: {USE_RELATIVE_STATE}")
if RELATIVE_STATS_PATH:
print(f" Stats path: {RELATIVE_STATS_PATH}")
print("=" * 70)
# Resolve URDF path
@@ -411,15 +447,13 @@ def main():
)
print(" Policy loaded")
# Load relative action normalizer if using relative actions
relative_normalizer = None
if USE_RELATIVE_ACTIONS and RELATIVE_STATS_PATH:
stats_path = Path(RELATIVE_STATS_PATH)
if stats_path.exists():
print(f" Loading relative stats from: {stats_path}")
relative_normalizer = PerTimestepNormalizer.load(stats_path)
else:
print(f" WARNING: Relative stats not found at {stats_path}")
# Auto-detect relative action/state settings from checkpoint
relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)
mode = "absolute"
if use_relative_actions:
mode = "relative actions + state" if use_relative_state else "relative actions only"
print(f" Mode: {mode}")
# Initialize keyboard listener
print("\n[4/4] Starting evaluation...")
@@ -449,8 +483,8 @@ def main():
fps=FPS,
duration_s=EPISODE_TIME_SEC,
events=events,
use_relative_actions=USE_RELATIVE_ACTIONS,
use_relative_state=USE_RELATIVE_STATE,
use_relative_actions=use_relative_actions,
use_relative_state=use_relative_state,
relative_normalizer=relative_normalizer,
)