mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
auto detect mode and stats
This commit is contained in:
@@ -35,12 +35,13 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.utils.relative_actions import (
|
||||
convert_state_to_relative,
|
||||
convert_from_relative_actions,
|
||||
PerTimestepNormalizer,
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
@@ -71,11 +72,6 @@ FPS = 30
|
||||
EPISODE_TIME_SEC = 1000
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# UMI-style relative action/state (must match training config)
|
||||
USE_RELATIVE_ACTIONS = False # If True, policy outputs relative EE actions
|
||||
USE_RELATIVE_STATE = False # If True, convert state to relative before policy input
|
||||
RELATIVE_STATS_PATH = None # Path to relative_stats.pt (for per-timestep normalization)
|
||||
|
||||
# Robot CAN interfaces
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
@@ -102,6 +98,51 @@ LEFT_URDF_JOINTS = [f"openarm_left_joint{i}" for i in range(1, 8)]
|
||||
RIGHT_URDF_JOINTS = [f"openarm_right_joint{i}" for i in range(1, 8)]
|
||||
|
||||
|
||||
def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool, bool]:
|
||||
"""Auto-detect relative action/state settings and load normalizer from checkpoint."""
|
||||
model_path = Path(model_path) if isinstance(model_path, str) else model_path
|
||||
normalizer = None
|
||||
use_relative_actions = False
|
||||
use_relative_state = False
|
||||
|
||||
# Try local path first
|
||||
if model_path.exists():
|
||||
stats_path = model_path / "relative_stats.pt"
|
||||
if stats_path.exists():
|
||||
normalizer = PerTimestepNormalizer.load(stats_path)
|
||||
use_relative_actions = True
|
||||
print(f" Loaded per-timestep stats from: {stats_path}")
|
||||
|
||||
config_path = model_path / "train_config.json"
|
||||
if config_path.exists():
|
||||
cfg = TrainPipelineConfig.from_pretrained(model_path)
|
||||
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
|
||||
use_relative_state = getattr(cfg, "use_relative_state", False)
|
||||
else:
|
||||
# Try hub
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
try:
|
||||
stats_file = hf_hub_download(repo_id=str(model_path), filename="relative_stats.pt")
|
||||
normalizer = PerTimestepNormalizer.load(stats_file)
|
||||
use_relative_actions = True
|
||||
print(" Loaded per-timestep stats from hub")
|
||||
except Exception:
|
||||
pass # No stats file means no relative actions
|
||||
|
||||
try:
|
||||
config_file = hf_hub_download(repo_id=str(model_path), filename="train_config.json")
|
||||
cfg = TrainPipelineConfig.from_pretrained(Path(config_file).parent)
|
||||
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
|
||||
use_relative_state = getattr(cfg, "use_relative_state", False)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f" Warning: Could not load relative config: {e}")
|
||||
|
||||
return normalizer, use_relative_actions, use_relative_state
|
||||
|
||||
|
||||
def build_kinematics_pipelines(urdf_path: str, left_ee_frame: str, right_ee_frame: str):
|
||||
"""Build FK and IK pipelines for bimanual robot."""
|
||||
left_kinematics = RobotKinematics(
|
||||
@@ -337,11 +378,6 @@ def main():
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"\nUMI-style relative mode:")
|
||||
print(f" Relative actions: {USE_RELATIVE_ACTIONS}")
|
||||
print(f" Relative state: {USE_RELATIVE_STATE}")
|
||||
if RELATIVE_STATS_PATH:
|
||||
print(f" Stats path: {RELATIVE_STATS_PATH}")
|
||||
print("=" * 70)
|
||||
|
||||
# Resolve URDF path
|
||||
@@ -411,15 +447,13 @@ def main():
|
||||
)
|
||||
print(" Policy loaded")
|
||||
|
||||
# Load relative action normalizer if using relative actions
|
||||
relative_normalizer = None
|
||||
if USE_RELATIVE_ACTIONS and RELATIVE_STATS_PATH:
|
||||
stats_path = Path(RELATIVE_STATS_PATH)
|
||||
if stats_path.exists():
|
||||
print(f" Loading relative stats from: {stats_path}")
|
||||
relative_normalizer = PerTimestepNormalizer.load(stats_path)
|
||||
else:
|
||||
print(f" WARNING: Relative stats not found at {stats_path}")
|
||||
# Auto-detect relative action/state settings from checkpoint
|
||||
relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)
|
||||
|
||||
mode = "absolute"
|
||||
if use_relative_actions:
|
||||
mode = "relative actions + state" if use_relative_state else "relative actions only"
|
||||
print(f" Mode: {mode}")
|
||||
|
||||
# Initialize keyboard listener
|
||||
print("\n[4/4] Starting evaluation...")
|
||||
@@ -449,8 +483,8 @@ def main():
|
||||
fps=FPS,
|
||||
duration_s=EPISODE_TIME_SEC,
|
||||
events=events,
|
||||
use_relative_actions=USE_RELATIVE_ACTIONS,
|
||||
use_relative_state=USE_RELATIVE_STATE,
|
||||
use_relative_actions=use_relative_actions,
|
||||
use_relative_state=use_relative_state,
|
||||
relative_normalizer=relative_normalizer,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user