mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
auto detect mode and stats
This commit is contained in:
@@ -35,12 +35,13 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.utils.relative_actions import (
|
from lerobot.utils.relative_actions import (
|
||||||
convert_state_to_relative,
|
convert_state_to_relative,
|
||||||
convert_from_relative_actions,
|
convert_from_relative_actions,
|
||||||
PerTimestepNormalizer,
|
PerTimestepNormalizer,
|
||||||
)
|
)
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||||
@@ -71,11 +72,6 @@ FPS = 30
|
|||||||
EPISODE_TIME_SEC = 1000
|
EPISODE_TIME_SEC = 1000
|
||||||
RESET_TIME_SEC = 60
|
RESET_TIME_SEC = 60
|
||||||
|
|
||||||
# UMI-style relative action/state (must match training config)
|
|
||||||
USE_RELATIVE_ACTIONS = False # If True, policy outputs relative EE actions
|
|
||||||
USE_RELATIVE_STATE = False # If True, convert state to relative before policy input
|
|
||||||
RELATIVE_STATS_PATH = None # Path to relative_stats.pt (for per-timestep normalization)
|
|
||||||
|
|
||||||
# Robot CAN interfaces
|
# Robot CAN interfaces
|
||||||
FOLLOWER_LEFT_PORT = "can0"
|
FOLLOWER_LEFT_PORT = "can0"
|
||||||
FOLLOWER_RIGHT_PORT = "can1"
|
FOLLOWER_RIGHT_PORT = "can1"
|
||||||
@@ -102,6 +98,51 @@ LEFT_URDF_JOINTS = [f"openarm_left_joint{i}" for i in range(1, 8)]
|
|||||||
RIGHT_URDF_JOINTS = [f"openarm_right_joint{i}" for i in range(1, 8)]
|
RIGHT_URDF_JOINTS = [f"openarm_right_joint{i}" for i in range(1, 8)]
|
||||||
|
|
||||||
|
|
||||||
|
def load_relative_config(model_path: Path | str) -> tuple[PerTimestepNormalizer | None, bool, bool]:
|
||||||
|
"""Auto-detect relative action/state settings and load normalizer from checkpoint."""
|
||||||
|
model_path = Path(model_path) if isinstance(model_path, str) else model_path
|
||||||
|
normalizer = None
|
||||||
|
use_relative_actions = False
|
||||||
|
use_relative_state = False
|
||||||
|
|
||||||
|
# Try local path first
|
||||||
|
if model_path.exists():
|
||||||
|
stats_path = model_path / "relative_stats.pt"
|
||||||
|
if stats_path.exists():
|
||||||
|
normalizer = PerTimestepNormalizer.load(stats_path)
|
||||||
|
use_relative_actions = True
|
||||||
|
print(f" Loaded per-timestep stats from: {stats_path}")
|
||||||
|
|
||||||
|
config_path = model_path / "train_config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
cfg = TrainPipelineConfig.from_pretrained(model_path)
|
||||||
|
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
|
||||||
|
use_relative_state = getattr(cfg, "use_relative_state", False)
|
||||||
|
else:
|
||||||
|
# Try hub
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
try:
|
||||||
|
stats_file = hf_hub_download(repo_id=str(model_path), filename="relative_stats.pt")
|
||||||
|
normalizer = PerTimestepNormalizer.load(stats_file)
|
||||||
|
use_relative_actions = True
|
||||||
|
print(" Loaded per-timestep stats from hub")
|
||||||
|
except Exception:
|
||||||
|
pass # No stats file means no relative actions
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_file = hf_hub_download(repo_id=str(model_path), filename="train_config.json")
|
||||||
|
cfg = TrainPipelineConfig.from_pretrained(Path(config_file).parent)
|
||||||
|
use_relative_actions = getattr(cfg, "use_relative_actions", use_relative_actions)
|
||||||
|
use_relative_state = getattr(cfg, "use_relative_state", False)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Warning: Could not load relative config: {e}")
|
||||||
|
|
||||||
|
return normalizer, use_relative_actions, use_relative_state
|
||||||
|
|
||||||
|
|
||||||
def build_kinematics_pipelines(urdf_path: str, left_ee_frame: str, right_ee_frame: str):
|
def build_kinematics_pipelines(urdf_path: str, left_ee_frame: str, right_ee_frame: str):
|
||||||
"""Build FK and IK pipelines for bimanual robot."""
|
"""Build FK and IK pipelines for bimanual robot."""
|
||||||
left_kinematics = RobotKinematics(
|
left_kinematics = RobotKinematics(
|
||||||
@@ -337,11 +378,6 @@ def main():
|
|||||||
print(f"Task: {TASK_DESCRIPTION}")
|
print(f"Task: {TASK_DESCRIPTION}")
|
||||||
print(f"Episodes: {NUM_EPISODES}")
|
print(f"Episodes: {NUM_EPISODES}")
|
||||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||||
print(f"\nUMI-style relative mode:")
|
|
||||||
print(f" Relative actions: {USE_RELATIVE_ACTIONS}")
|
|
||||||
print(f" Relative state: {USE_RELATIVE_STATE}")
|
|
||||||
if RELATIVE_STATS_PATH:
|
|
||||||
print(f" Stats path: {RELATIVE_STATS_PATH}")
|
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
|
|
||||||
# Resolve URDF path
|
# Resolve URDF path
|
||||||
@@ -411,15 +447,13 @@ def main():
|
|||||||
)
|
)
|
||||||
print(" Policy loaded")
|
print(" Policy loaded")
|
||||||
|
|
||||||
# Load relative action normalizer if using relative actions
|
# Auto-detect relative action/state settings from checkpoint
|
||||||
relative_normalizer = None
|
relative_normalizer, use_relative_actions, use_relative_state = load_relative_config(HF_MODEL_ID)
|
||||||
if USE_RELATIVE_ACTIONS and RELATIVE_STATS_PATH:
|
|
||||||
stats_path = Path(RELATIVE_STATS_PATH)
|
mode = "absolute"
|
||||||
if stats_path.exists():
|
if use_relative_actions:
|
||||||
print(f" Loading relative stats from: {stats_path}")
|
mode = "relative actions + state" if use_relative_state else "relative actions only"
|
||||||
relative_normalizer = PerTimestepNormalizer.load(stats_path)
|
print(f" Mode: {mode}")
|
||||||
else:
|
|
||||||
print(f" WARNING: Relative stats not found at {stats_path}")
|
|
||||||
|
|
||||||
# Initialize keyboard listener
|
# Initialize keyboard listener
|
||||||
print("\n[4/4] Starting evaluation...")
|
print("\n[4/4] Starting evaluation...")
|
||||||
@@ -449,8 +483,8 @@ def main():
|
|||||||
fps=FPS,
|
fps=FPS,
|
||||||
duration_s=EPISODE_TIME_SEC,
|
duration_s=EPISODE_TIME_SEC,
|
||||||
events=events,
|
events=events,
|
||||||
use_relative_actions=USE_RELATIVE_ACTIONS,
|
use_relative_actions=use_relative_actions,
|
||||||
use_relative_state=USE_RELATIVE_STATE,
|
use_relative_state=use_relative_state,
|
||||||
relative_normalizer=relative_normalizer,
|
relative_normalizer=relative_normalizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user