This commit is contained in:
Pepijn
2026-01-09 09:56:14 +01:00
parent 8e430f323f
commit 33f84fe0ec
+382 -211
View File
@@ -15,11 +15,11 @@
# limitations under the License. # limitations under the License.
""" """
OpenArms Policy Evaluation with Interpolation OpenArms Policy Evaluation with RTC and Interpolation
Evaluates a trained policy with smooth action interpolation: Evaluates a trained policy with:
- Decoupled camera capture (CAMERA_FPS) from robot control (ROBOT_FPS) - RTC (Real-Time Chunking) for async inference - decouples policy from robot loop
- Speed multiplier to execute actions faster than training - Smooth action interpolation for high-frequency robot control
- Velocity feedforward for smoother tracking - Velocity feedforward for smoother tracking
- Adjustable PID gains - Adjustable PID gains
@@ -27,27 +27,41 @@ Example usage:
python examples/openarms/evaluate_interpolation.py python examples/openarms/evaluate_interpolation.py
""" """
import logging
import math
import sys
import time import time
import traceback
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from threading import Event, Lock, Thread
import numpy as np import numpy as np
import torch
from torch import Tensor
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.datasets.utils import combine_feature_dicts from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor import make_default_processors from lerobot.processor import make_default_processors
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
from lerobot.utils.control_utils import init_keyboard_listener, predict_action from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say, get_safe_torch_device from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ======================== MODEL & TASK CONFIG ======================== # ======================== MODEL & TASK CONFIG ========================
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
@@ -57,81 +71,107 @@ TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task
# ======================== TIMING CONFIG ======================== # ======================== TIMING CONFIG ========================
CAMERA_FPS = 30 # Camera hardware limit (fixed) CAMERA_FPS = 30 # Camera hardware limit (fixed)
POLICY_FPS = 30 # What the policy was trained with POLICY_FPS = 30 # What the policy was trained with
SPEED_MULTIPLIER = 1.2 # Execute actions faster (1.0 = normal, 1.2 = 20% faster)
ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation) ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation)
# Derived values
EFFECTIVE_POLICY_FPS = int(POLICY_FPS * SPEED_MULTIPLIER) # How fast we consume actions (36Hz at 1.2x)
NUM_EPISODES = 1 NUM_EPISODES = 1
EPISODE_TIME_SEC = 300 EPISODE_TIME_SEC = 300
RESET_TIME_SEC = 60 RESET_TIME_SEC = 60
# ======================== RTC CONFIG ========================
# RTC (Real-Time Chunking) settings. The three RTC_* values are passed to
# RTCConfig(enabled=..., execution_horizon=..., max_guidance_weight=...) in main().
RTC_ENABLED = True
RTC_EXECUTION_HORIZON = 20
RTC_MAX_GUIDANCE_WEIGHT = 5.0
# Queue-size threshold used by get_actions_thread: a new chunk is requested when
# action_queue.qsize() is at or below this value (0 is used instead when RTC is disabled).
ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS = 30  # Should be > inference_delay + execution_horizon
# ======================== PID TUNING ======================== # ======================== PID TUNING ========================
# Set to None to use robot config defaults CUSTOM_KP_SCALE = 1.0 # Scale factor for position gain (0.5-1.0, lower = smoother)
CUSTOM_KP_SCALE = 0.7 # Scale factor for position gain (0.5-1.0, lower = smoother) CUSTOM_KD_SCALE = 1.0 # Scale factor for damping gain (1.0-2.0, higher = less overshoot)
CUSTOM_KD_SCALE = 1.3 # Scale factor for damping gain (1.0-2.0, higher = less overshoot) USE_VELOCITY_FEEDFORWARD = False # Enable velocity feedforward for smoother tracking
USE_VELOCITY_FEEDFORWARD = True # Enable velocity feedforward for smoother tracking
# ======================== ROBOT CONFIG ======================== # ======================== ROBOT CONFIG ========================
FOLLOWER_LEFT_PORT = "can0" FOLLOWER_LEFT_PORT = "can0"
FOLLOWER_RIGHT_PORT = "can1" FOLLOWER_RIGHT_PORT = "can1"
USE_LEADER_FOR_RESETS = True USE_LEADER_FOR_RESETS = False
LEADER_LEFT_PORT = "can2" LEADER_LEFT_PORT = "can2"
LEADER_RIGHT_PORT = "can3" LEADER_RIGHT_PORT = "can3"
# Camera config uses CAMERA_FPS (hardware limit) DEVICE = "cuda"
CAMERA_CONFIG = { CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=CAMERA_FPS), "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=1280, height=720, fps=CAMERA_FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=CAMERA_FPS), "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=1280, height=720, fps=CAMERA_FPS),
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=CAMERA_FPS), "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=CAMERA_FPS),
} }
class RobotWrapper:
    """Serialize access to a robot that is shared between threads.

    Observation reads, action writes and the feature look-ups all run while
    holding one ``Lock``; ``name`` is forwarded without taking the lock.
    The wrapped robot stays reachable as ``self.robot``.
    """

    def __init__(self, robot: OpenArmsFollower):
        self.lock = Lock()
        self.robot = robot

    def get_observation(self) -> dict[str, Tensor]:
        """Return the wrapped robot's observation, read under the lock."""
        with self.lock:
            observation = self.robot.get_observation()
        return observation

    def send_action(self, action: dict, **kwargs) -> None:
        """Forward ``action`` and any keyword arguments to the robot under the lock.

        The wrapped robot's return value is discarded.
        """
        with self.lock:
            self.robot.send_action(action, **kwargs)
        return None

    @property
    def observation_features(self) -> dict:
        """The wrapped robot's ``observation_features``, read under the lock."""
        with self.lock:
            features = self.robot.observation_features
        return features

    @property
    def action_features(self) -> dict:
        """The wrapped robot's ``action_features``, read under the lock."""
        with self.lock:
            features = self.robot.action_features
        return features

    @property
    def name(self) -> str:
        """The wrapped robot's ``name`` (no locking)."""
        return self.robot.name
class ActionInterpolator: class ActionInterpolator:
"""Interpolate between policy actions for smoother robot control with velocity estimation.""" """Interpolate between consecutive actions for smoother robot control."""
def __init__(self, effective_policy_fps: int, robot_fps: int): def __init__(self, policy_fps: int, robot_fps: int):
self.effective_policy_fps = effective_policy_fps self.policy_fps = policy_fps
self.robot_fps = robot_fps self.robot_fps = robot_fps
self.substeps_per_policy_step = robot_fps / effective_policy_fps self.substeps_per_policy_step = robot_fps / policy_fps
self.prev_action: dict | None = None self.prev_action: Tensor | None = None
self.curr_action: dict | None = None self.curr_action: Tensor | None = None
self.substep = 0 self.substep = 0
self.last_interpolated: dict | None = None self.last_interpolated: Tensor | None = None
def update(self, new_action: Tensor) -> None:
    """Make ``new_action`` the interpolation target.

    The previous target becomes the interpolation start point and the
    sub-step counter restarts from zero.
    """
    self.prev_action, self.curr_action = self.curr_action, new_action
    self.substep = 0
def get_interpolated_action(self) -> tuple[dict | None, dict | None]: def get_interpolated_action(self) -> tuple[Tensor | None, Tensor | None]:
"""Returns (interpolated_position, estimated_velocity_deg_per_sec)""" """Returns (interpolated_action, estimated_velocity)"""
if self.curr_action is None: if self.curr_action is None:
return None, None return None, None
if self.prev_action is None: if self.prev_action is None:
self.last_interpolated = self.curr_action.copy() self.last_interpolated = self.curr_action.clone()
return self.curr_action, {k: 0.0 for k in self.curr_action} return self.curr_action, torch.zeros_like(self.curr_action)
t = min(self.substep / self.substeps_per_policy_step, 1.0) t = min(self.substep / self.substeps_per_policy_step, 1.0)
self.substep += 1 self.substep += 1
interpolated = {} interpolated = self.prev_action * (1 - t) + self.curr_action * t
velocity = {}
dt = 1.0 / self.robot_fps dt = 1.0 / self.robot_fps
if self.last_interpolated is not None:
velocity = (interpolated - self.last_interpolated) / dt
else:
velocity = (self.curr_action - self.prev_action) * self.policy_fps
for key in self.curr_action: self.last_interpolated = interpolated.clone()
prev = self.prev_action.get(key, self.curr_action[key])
curr = self.curr_action[key]
interpolated[key] = prev * (1 - t) + curr * t
if self.last_interpolated is not None and key in self.last_interpolated:
velocity[key] = (interpolated[key] - self.last_interpolated[key]) / dt
else:
velocity[key] = (curr - prev) * self.effective_policy_fps
self.last_interpolated = interpolated.copy()
return interpolated, velocity return interpolated, velocity
def reset(self): def reset(self):
@@ -175,160 +215,230 @@ class HzTracker:
self.last_print_time = 0 self.last_print_time = 0
def get_actions_thread(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    episode_active: Event,
    rtc_config: RTCConfig,
    policy_fps: int,
    task: str,
    pretrained_path: str,
    device: str,
):
    """Background worker that keeps the action queue filled with policy chunks.

    While ``shutdown_event`` is clear and ``episode_active`` is set, the worker
    waits until the queue size is at or below the refill threshold, reads an
    observation, runs ``policy.predict_action_chunk`` and merges the result
    into ``action_queue``. The pre/post-processors are loaded from
    ``pretrained_path`` inside this thread.

    Any exception is logged, ``shutdown_event`` is set, and ``sys.exit(1)`` is
    called.
    """
    try:
        logger.info("[GET_ACTIONS] Starting action generation thread")

        latency_tracker = LatencyTracker()
        # Duration of one policy step; used to turn a latency into a step count.
        time_per_chunk = 1.0 / policy_fps
        hw_features = hw_to_dataset_features(robot.observation_features, "observation")
        policy_device = device

        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {pretrained_path}")
        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=policy.config,
            pretrained_path=pretrained_path,
            dataset_stats=None,
            preprocessor_overrides={"device_processor": {"device": device}},
        )
        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")

        if rtc_config.enabled:
            get_actions_threshold = ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS
        else:
            get_actions_threshold = 0

        while not shutdown_event.is_set():
            # Idle while no episode is running or the queue is still full enough.
            if not episode_active.is_set() or action_queue.qsize() > get_actions_threshold:
                time.sleep(0.01)
                continue

            started_at = time.perf_counter()
            action_index_before_inference = action_queue.get_action_index()
            prev_actions = action_queue.get_left_over()

            # Expected delay (in policy steps) from the worst latency seen so far.
            inference_latency = latency_tracker.max()
            if inference_latency:
                inference_delay = math.ceil(inference_latency / time_per_chunk)
            else:
                inference_delay = 0

            raw_obs = robot.get_observation()
            processed_obs = robot_observation_processor(raw_obs)
            frame = build_dataset_frame(hw_features, processed_obs, prefix="observation")

            for name in frame:
                tensor = torch.from_numpy(frame[name])
                if "image" in name:
                    # Scale to float32 in [0, 1] and move the last axis to the front.
                    tensor = tensor.type(torch.float32) / 255
                    tensor = tensor.permute(2, 0, 1).contiguous()
                frame[name] = tensor.unsqueeze(0).to(policy_device)

            frame["task"] = [task]
            frame["robot_type"] = robot.name

            policy_input = preprocessor(frame)
            actions = policy.predict_action_chunk(
                policy_input,
                inference_delay=inference_delay,
                prev_chunk_left_over=prev_actions,
            )

            # Keep the raw chunk as well as the post-processed one; both go to merge().
            original_actions = actions.squeeze(0).clone()
            postprocessed_actions = postprocessor(actions).squeeze(0)

            new_latency = time.perf_counter() - started_at
            new_delay = math.ceil(new_latency / time_per_chunk)
            latency_tracker.add(new_latency)

            if ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS < rtc_config.execution_horizon + new_delay:
                logger.warning(
                    "[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
                    "Should be higher than inference delay + execution horizon."
                )

            action_queue.merge(
                original_actions, postprocessed_actions, new_delay, action_index_before_inference
            )

            logger.debug(
                f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
                f"delay={new_delay}, queue_size={action_queue.qsize()}"
            )

        logger.info("[GET_ACTIONS] Action generation thread shutting down")

    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
        logger.error(traceback.format_exc())
        shutdown_event.set()
        sys.exit(1)
def actor_thread(
robot: RobotWrapper,
robot_action_processor, robot_action_processor,
dataset, action_queue: ActionQueue,
events, shutdown_event: Event,
episode_active: Event,
interpolator: ActionInterpolator, interpolator: ActionInterpolator,
robot_hz_tracker: HzTracker, robot_hz_tracker: HzTracker,
camera_fps: int,
effective_policy_fps: int,
robot_fps: int, robot_fps: int,
control_time_s: float, action_keys: list[str],
task: str, custom_kp: dict | None,
kp_scale: float | None = None, custom_kd: dict | None,
kd_scale: float | None = None, use_velocity_ff: bool,
use_velocity_ff: bool = False,
): ):
""" """Thread function to execute interpolated actions on the robot at high frequency."""
Run evaluation with decoupled camera and robot control: try:
- Camera captures at camera_fps (hardware limit) logger.info("[ACTOR] Starting actor thread")
- Policy inference runs when new camera frame is available
- Actions are consumed at effective_policy_fps (sped up by SPEED_MULTIPLIER) action_count = 0
- Robot receives interpolated commands at robot_fps (smoothest) action_interval = 1.0 / robot_fps
"""
from lerobot.scripts.lerobot_record import build_dataset_frame, make_robot_action while not shutdown_event.is_set():
from lerobot.utils.visualization_utils import log_rerun_data if not episode_active.is_set():
time.sleep(0.01)
camera_dt = 1.0 / camera_fps continue
policy_dt = 1.0 / effective_policy_fps
robot_dt = 1.0 / robot_fps start_time = time.perf_counter()
interpolator.reset()
robot_hz_tracker.reset()
policy.reset()
# Build custom gains if scaling is enabled
custom_kp = None
custom_kd = None
if kp_scale is not None or kd_scale is not None:
custom_kp = {}
custom_kd = {}
for arm in ["right", "left"]:
bus = robot.bus_right if arm == "right" else robot.bus_left
for i, motor_name in enumerate(bus.motors):
full_name = f"{arm}_{motor_name}"
default_kp = robot.config.position_kp[i] if isinstance(robot.config.position_kp, list) else robot.config.position_kp
default_kd = robot.config.position_kd[i] if isinstance(robot.config.position_kd, list) else robot.config.position_kd
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
print(f"Custom gains: kp_scale={kp_scale}, kd_scale={kd_scale}")
if use_velocity_ff:
print("Velocity feedforward: enabled")
last_camera_time = -camera_dt
last_policy_action_time = -policy_dt
cached_observation = None
cached_robot_action = None
start_time = time.perf_counter()
print(f"\nStarting interpolated eval loop:")
print(f" Camera: {camera_fps}Hz | Policy actions consumed: {effective_policy_fps}Hz | Robot: {robot_fps}Hz")
while time.perf_counter() - start_time < control_time_s:
if events["exit_early"] or events["stop_recording"]:
break
loop_start = time.perf_counter() # Get new action from queue and update interpolator
elapsed = loop_start - start_time action = action_queue.get()
if action is not None:
# === CAMERA CAPTURE (at camera_fps, decoupled from robot) === interpolator.update(action.cpu())
if elapsed - last_camera_time >= camera_dt:
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
# Run policy inference with fresh observation # Get interpolated action for smooth control
action_values = predict_action( smooth_action, velocity = interpolator.get_interpolated_action()
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=task,
robot_type=robot.robot_type,
)
act_processed = make_robot_action(action_values, dataset.features) if smooth_action is not None:
cached_robot_action = robot_action_processor((act_processed, obs)) action_dict = {}
cached_observation = (obs_processed, observation_frame, act_processed) for i, key in enumerate(action_keys):
if i < len(smooth_action):
last_camera_time = elapsed action_dict[key] = smooth_action[i].item()
# === ACTION UPDATE (at effective_policy_fps, faster than camera if speed > 1) === action_processed = robot_action_processor((action_dict, None))
if elapsed - last_policy_action_time >= policy_dt and cached_robot_action is not None:
interpolator.update(cached_robot_action) vel_ff = None
last_policy_action_time = elapsed if use_velocity_ff and velocity is not None:
vel_ff = {}
# Save to dataset at effective policy rate for i, key in enumerate(action_keys):
if dataset is not None and cached_observation is not None: if i < len(velocity):
obs_processed, observation_frame, act_processed = cached_observation motor_name = key.replace(".pos", "")
action_frame = build_dataset_frame(dataset.features, act_processed, prefix="action") vel_ff[motor_name] = velocity[i].item()
frame = {**observation_frame, **action_frame, "task": task}
dataset.add_frame(frame) robot.send_action(action_processed, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff)
log_rerun_data(observation=obs_processed, action=act_processed) action_count += 1
# === ROBOT COMMAND (at robot_fps, highest rate for smoothness) === robot_hz_tracker.tick()
smooth_action, velocity = interpolator.get_interpolated_action()
if smooth_action is not None: dt_s = time.perf_counter() - start_time
vel_ff = velocity if use_velocity_ff else None sleep_time = max(0, action_interval - dt_s - 0.001)
robot.send_action(smooth_action, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff) if sleep_time > 0:
time.sleep(sleep_time)
robot_hz_tracker.tick()
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
# Maintain robot control rate except Exception as e:
sleep_time = robot_dt - (time.perf_counter() - loop_start) logger.error(f"[ACTOR] Fatal exception: {e}")
if sleep_time > 0: logger.error(traceback.format_exc())
time.sleep(sleep_time) shutdown_event.set()
sys.exit(1)
def build_custom_gains(robot, kp_scale: float | None, kd_scale: float | None) -> tuple[dict | None, dict | None]:
"""Build custom KP/KD gains for the robot."""
if kp_scale is None and kd_scale is None:
return None, None
# Print final stats custom_kp = {}
robot_hz = robot_hz_tracker.get_avg_hz() custom_kd = {}
if robot_hz: for arm in ["right", "left"]:
print(f"\nFinal average robot Hz: {robot_hz:.1f}") bus = robot.robot.bus_right if arm == "right" else robot.robot.bus_left
for i, motor_name in enumerate(bus.motors):
full_name = f"{arm}_{motor_name}"
default_kp = robot.robot.config.position_kp[i] if isinstance(robot.robot.config.position_kp, list) else robot.robot.config.position_kp
default_kd = robot.robot.config.position_kd[i] if isinstance(robot.robot.config.position_kd, list) else robot.robot.config.position_kd
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
return custom_kp, custom_kd
def main(): def main():
"""Main evaluation function.""" """Main evaluation function with RTC and interpolation."""
print("=" * 60) print("=" * 60)
print("OpenArms Policy Evaluation with Interpolation") print("OpenArms Policy Evaluation with RTC + Interpolation")
print("=" * 60) print("=" * 60)
print(f"\nModel: {HF_MODEL_ID}") print(f"\nModel: {HF_MODEL_ID}")
print(f"Dataset: {HF_EVAL_DATASET_ID}") print(f"Dataset: {HF_EVAL_DATASET_ID}")
print(f"Task: {TASK_DESCRIPTION}") print(f"Task: {TASK_DESCRIPTION}")
print(f"\n--- Timing ---") print(f"\n--- Timing ---")
print(f"Camera FPS: {CAMERA_FPS} (hardware limit)") print(f"Policy FPS: {POLICY_FPS}Hz")
print(f"Policy trained at: {POLICY_FPS}Hz") print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated)")
print(f"Speed multiplier: {SPEED_MULTIPLIER}x") print(f"\n--- RTC ---")
print(f"Effective policy FPS: {EFFECTIVE_POLICY_FPS}Hz (actions consumed)") print(f"RTC Enabled: {RTC_ENABLED}")
print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated commands)") print(f"Execution Horizon: {RTC_EXECUTION_HORIZON}")
print(f"\n--- PID Tuning ---") print(f"Max Guidance Weight: {RTC_MAX_GUIDANCE_WEIGHT}")
print(f"KP scale: {CUSTOM_KP_SCALE}") print(f"\n--- PID ---")
print(f"KD scale: {CUSTOM_KD_SCALE}") print(f"KP scale: {CUSTOM_KP_SCALE}, KD scale: {CUSTOM_KD_SCALE}")
print(f"Velocity feedforward: {USE_VELOCITY_FEEDFORWARD}") print(f"Velocity FF: {USE_VELOCITY_FEEDFORWARD}")
print(f"\n--- Episodes ---") print(f"\n--- Episodes ---")
print(f"Episodes: {NUM_EPISODES}") print(f"Episodes: {NUM_EPISODES}, Duration: {EPISODE_TIME_SEC}s")
print(f"Duration: {EPISODE_TIME_SEC}s per episode")
print(f"Reset time: {RESET_TIME_SEC}s")
print(f"Leader for resets: {USE_LEADER_FOR_RESETS}")
print("=" * 60) print("=" * 60)
shutdown_event = Event()
episode_active = Event()
follower_config = OpenArmsFollowerConfig( follower_config = OpenArmsFollowerConfig(
port_left=FOLLOWER_LEFT_PORT, port_left=FOLLOWER_LEFT_PORT,
@@ -346,6 +456,9 @@ def main():
if not follower.is_connected: if not follower.is_connected:
raise RuntimeError("Follower robot failed to connect!") raise RuntimeError("Follower robot failed to connect!")
robot = RobotWrapper(follower)
logger.info("Follower robot connected")
leader = None leader = None
if USE_LEADER_FOR_RESETS: if USE_LEADER_FOR_RESETS:
leader_config = OpenArmsLeaderConfig( leader_config = OpenArmsLeaderConfig(
@@ -366,9 +479,9 @@ def main():
leader.bus_right.enable_torque() leader.bus_right.enable_torque()
leader.bus_left.enable_torque() leader.bus_left.enable_torque()
time.sleep(0.1) time.sleep(0.1)
print(f"Leader connected with gravity compensation") print("Leader connected with gravity compensation")
else: else:
print(f"Leader connected (no gravity compensation)") print("Leader connected (no gravity compensation)")
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
@@ -401,10 +514,9 @@ def main():
leader.disconnect() leader.disconnect()
return return
# Dataset uses effective policy FPS (sped up rate)
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id=HF_EVAL_DATASET_ID, repo_id=HF_EVAL_DATASET_ID,
fps=EFFECTIVE_POLICY_FPS, fps=POLICY_FPS,
features=dataset_features, features=dataset_features,
robot_type=follower.name, robot_type=follower.name,
use_videos=True, use_videos=True,
@@ -412,53 +524,102 @@ def main():
image_writer_threads=12, image_writer_threads=12,
) )
# Load policy with RTC support
logger.info(f"Loading policy from: {HF_MODEL_ID}")
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
policy_config.pretrained_path = HF_MODEL_ID policy_config.pretrained_path = HF_MODEL_ID
policy = make_policy(policy_config, ds_meta=dataset.meta)
preprocessor, postprocessor = make_pre_post_processors( policy_class = get_policy_class(policy_config.type)
policy_cfg=policy.config, policy = policy_class.from_pretrained(HF_MODEL_ID, config=policy_config)
pretrained_path=HF_MODEL_ID,
dataset_stats=dataset.meta.stats, rtc_config = RTCConfig(
preprocessor_overrides={ enabled=RTC_ENABLED,
"device_processor": {"device": str(policy.config.device)} execution_horizon=RTC_EXECUTION_HORIZON,
}, max_guidance_weight=RTC_MAX_GUIDANCE_WEIGHT,
prefix_attention_schedule=RTCAttentionSchedule.EXP,
) )
policy.config.rtc_config = rtc_config
policy.init_rtc_processor()
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 support RTC"
policy = policy.to(DEVICE)
policy.eval()
logger.info(f"Policy loaded: {policy.name}")
print(f"\nRunning evaluation...") print(f"\nRunning evaluation...")
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
init_rerun(session_name="openarms_evaluation_interp") init_rerun(session_name="openarms_eval_rtc_interp")
interpolator = ActionInterpolator(effective_policy_fps=EFFECTIVE_POLICY_FPS, robot_fps=ROBOT_FPS) action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0) custom_kp, custom_kd = build_custom_gains(robot, CUSTOM_KP_SCALE, CUSTOM_KD_SCALE)
if custom_kp:
print(f"Custom gains applied")
if USE_VELOCITY_FEEDFORWARD:
print("Velocity feedforward: enabled")
episode_idx = 0 episode_idx = 0
get_actions_t = None
actor_t = None
try: try:
while episode_idx < NUM_EPISODES and not events["stop_recording"]: while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---") print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---")
interpolated_eval_loop( action_queue = ActionQueue(rtc_config)
robot=follower, interpolator = ActionInterpolator(policy_fps=POLICY_FPS, robot_fps=ROBOT_FPS)
policy=policy, robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
preprocessor=preprocessor,
postprocessor=postprocessor, get_actions_t = Thread(
robot_observation_processor=robot_observation_processor, target=get_actions_thread,
robot_action_processor=robot_action_processor, args=(
dataset=dataset, policy, robot, robot_observation_processor, action_queue,
events=events, shutdown_event, episode_active, rtc_config, POLICY_FPS,
interpolator=interpolator, TASK_DESCRIPTION, HF_MODEL_ID, DEVICE,
robot_hz_tracker=robot_hz_tracker, ),
camera_fps=CAMERA_FPS, daemon=True,
effective_policy_fps=EFFECTIVE_POLICY_FPS, name="GetActions",
robot_fps=ROBOT_FPS,
control_time_s=EPISODE_TIME_SEC,
task=TASK_DESCRIPTION,
kp_scale=CUSTOM_KP_SCALE,
kd_scale=CUSTOM_KD_SCALE,
use_velocity_ff=USE_VELOCITY_FEEDFORWARD,
) )
get_actions_t.start()
actor_t = Thread(
target=actor_thread,
args=(
robot, robot_action_processor, action_queue,
shutdown_event, episode_active, interpolator, robot_hz_tracker,
ROBOT_FPS, action_keys, custom_kp, custom_kd, USE_VELOCITY_FEEDFORWARD,
),
daemon=True,
name="Actor",
)
actor_t.start()
logger.info("Started inference and actor threads")
episode_active.set()
episode_start_time = time.time()
# Elapsed time at which the next progress line is due; advancing it after each
# log gives exactly one line per 10 s window regardless of the poll interval.
next_progress_log_s = 10.0
while (time.time() - episode_start_time) < EPISODE_TIME_SEC:
    if events["exit_early"] or events["stop_recording"] or shutdown_event.is_set():
        break

    elapsed = time.time() - episode_start_time
    if elapsed >= next_progress_log_s:
        next_progress_log_s += 10.0
        robot_hz = robot_hz_tracker.get_avg_hz()
        # The fallback belongs in the expression, not after the colon:
        # "{robot_hz:.1f if robot_hz else 0}" makes the whole tail the format
        # spec and raises at runtime.
        logger.info(
            f"Progress: {elapsed:.0f}/{EPISODE_TIME_SEC}s, "
            f"queue={action_queue.qsize()}, hz={(robot_hz or 0.0):.1f}"
        )

    time.sleep(0.5)

episode_active.clear()

robot_hz = robot_hz_tracker.get_avg_hz()
logger.info(f"Episode {episode_idx + 1} done. Avg Hz: {(robot_hz or 0.0):.1f}")
if events["rerecord_episode"]: if events["rerecord_episode"]:
log_say("Re-recording episode") log_say("Re-recording episode")
@@ -566,6 +727,15 @@ def main():
print("\n\nInterrupted by user") print("\n\nInterrupted by user")
finally: finally:
shutdown_event.set()
episode_active.clear()
if get_actions_t is not None and get_actions_t.is_alive():
get_actions_t.join(timeout=2.0)
if actor_t is not None and actor_t.is_alive():
actor_t.join(timeout=2.0)
if leader: if leader:
leader.bus_right.disable_torque() leader.bus_right.disable_torque()
leader.bus_left.disable_torque() leader.bus_left.disable_torque()
@@ -573,6 +743,7 @@ def main():
leader.disconnect() leader.disconnect()
follower.disconnect() follower.disconnect()
logger.info("Follower disconnected")
if listener is not None: if listener is not None:
listener.stop() listener.stop()