mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
do evaluate with rtc
This commit is contained in:
@@ -15,11 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation with RTC and Interpolation
|
||||
OpenArms Policy Evaluation with Interpolation
|
||||
|
||||
Evaluates a trained policy with:
|
||||
- RTC (Real-Time Chunking) for async inference - decouples policy from robot loop
|
||||
- Smooth action interpolation for high-frequency robot control
|
||||
Evaluates a trained policy with smooth action interpolation:
|
||||
- Decoupled camera capture (CAMERA_FPS) from robot control (ROBOT_FPS)
|
||||
- Speed multiplier to execute actions faster than training
|
||||
- Velocity feedforward for smoother tracking
|
||||
- Adjustable PID gains
|
||||
|
||||
@@ -27,41 +27,27 @@ Example usage:
|
||||
python examples/openarms/evaluate_interpolation.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.datasets.utils import combine_feature_dicts
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators.openarms.config_openarms_leader import OpenArmsLeaderConfig
|
||||
from lerobot.teleoperators.openarms.openarms_leader import OpenArmsLeader
|
||||
from lerobot.utils.control_utils import init_keyboard_listener
|
||||
from lerobot.utils.utils import log_say
|
||||
from lerobot.utils.control_utils import init_keyboard_listener, predict_action
|
||||
from lerobot.utils.utils import log_say, get_safe_torch_device
|
||||
from lerobot.utils.visualization_utils import init_rerun
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ======================== MODEL & TASK CONFIG ========================
|
||||
HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0" # TODO: Replace with your trained model
|
||||
@@ -71,109 +57,83 @@ TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task
|
||||
# ======================== TIMING CONFIG ========================
|
||||
CAMERA_FPS = 30 # Camera hardware limit (fixed)
|
||||
POLICY_FPS = 30 # What the policy was trained with
|
||||
SPEED_MULTIPLIER = 1.2 # Execute actions faster (1.0 = normal, 1.2 = 20% faster)
|
||||
ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation)
|
||||
|
||||
# Derived values
|
||||
EFFECTIVE_POLICY_FPS = int(POLICY_FPS * SPEED_MULTIPLIER) # How fast we consume actions (36Hz at 1.2x)
|
||||
|
||||
NUM_EPISODES = 1
|
||||
EPISODE_TIME_SEC = 300
|
||||
RESET_TIME_SEC = 60
|
||||
|
||||
# ======================== RTC CONFIG ========================
|
||||
RTC_ENABLED = True
|
||||
RTC_EXECUTION_HORIZON = 20
|
||||
RTC_MAX_GUIDANCE_WEIGHT = 5.0
|
||||
ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS = 30 # Should be > inference_delay + execution_horizon
|
||||
|
||||
# ======================== PID TUNING ========================
|
||||
CUSTOM_KP_SCALE = 1.0 # Scale factor for position gain (0.5-1.0, lower = smoother)
|
||||
CUSTOM_KD_SCALE = 1.0 # Scale factor for damping gain (1.0-2.0, higher = less overshoot)
|
||||
USE_VELOCITY_FEEDFORWARD = False # Enable velocity feedforward for smoother tracking
|
||||
# Set to None to use robot config defaults
|
||||
CUSTOM_KP_SCALE = 0.7 # Scale factor for position gain (0.5-1.0, lower = smoother)
|
||||
CUSTOM_KD_SCALE = 1.3 # Scale factor for damping gain (1.0-2.0, higher = less overshoot)
|
||||
USE_VELOCITY_FEEDFORWARD = True # Enable velocity feedforward for smoother tracking
|
||||
|
||||
# ======================== ROBOT CONFIG ========================
|
||||
FOLLOWER_LEFT_PORT = "can0"
|
||||
FOLLOWER_RIGHT_PORT = "can1"
|
||||
|
||||
USE_LEADER_FOR_RESETS = False
|
||||
USE_LEADER_FOR_RESETS = True
|
||||
LEADER_LEFT_PORT = "can2"
|
||||
LEADER_RIGHT_PORT = "can3"
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
# Camera config uses CAMERA_FPS (hardware limit)
|
||||
CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=1280, height=720, fps=CAMERA_FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=1280, height=720, fps=CAMERA_FPS),
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=CAMERA_FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=CAMERA_FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=CAMERA_FPS),
|
||||
}
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: OpenArmsFollower):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict, **kwargs) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action, **kwargs)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolate between consecutive actions for smoother robot control."""
|
||||
|
||||
def __init__(self, policy_fps: int, robot_fps: int):
|
||||
self.policy_fps = policy_fps
|
||||
"""Interpolate between policy actions for smoother robot control with velocity estimation."""
|
||||
|
||||
def __init__(self, effective_policy_fps: int, robot_fps: int):
|
||||
self.effective_policy_fps = effective_policy_fps
|
||||
self.robot_fps = robot_fps
|
||||
self.substeps_per_policy_step = robot_fps / policy_fps
|
||||
self.prev_action: Tensor | None = None
|
||||
self.curr_action: Tensor | None = None
|
||||
self.substeps_per_policy_step = robot_fps / effective_policy_fps
|
||||
self.prev_action: dict | None = None
|
||||
self.curr_action: dict | None = None
|
||||
self.substep = 0
|
||||
self.last_interpolated: Tensor | None = None
|
||||
|
||||
def update(self, new_action: Tensor) -> None:
|
||||
self.last_interpolated: dict | None = None
|
||||
|
||||
def update(self, new_action: dict) -> None:
|
||||
self.prev_action = self.curr_action
|
||||
self.curr_action = new_action
|
||||
self.substep = 0
|
||||
|
||||
def get_interpolated_action(self) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Returns (interpolated_action, estimated_velocity)"""
|
||||
|
||||
def get_interpolated_action(self) -> tuple[dict | None, dict | None]:
|
||||
"""Returns (interpolated_position, estimated_velocity_deg_per_sec)"""
|
||||
if self.curr_action is None:
|
||||
return None, None
|
||||
if self.prev_action is None:
|
||||
self.last_interpolated = self.curr_action.clone()
|
||||
return self.curr_action, torch.zeros_like(self.curr_action)
|
||||
|
||||
self.last_interpolated = self.curr_action.copy()
|
||||
return self.curr_action, {k: 0.0 for k in self.curr_action}
|
||||
|
||||
t = min(self.substep / self.substeps_per_policy_step, 1.0)
|
||||
self.substep += 1
|
||||
|
||||
interpolated = self.prev_action * (1 - t) + self.curr_action * t
|
||||
|
||||
|
||||
interpolated = {}
|
||||
velocity = {}
|
||||
dt = 1.0 / self.robot_fps
|
||||
if self.last_interpolated is not None:
|
||||
velocity = (interpolated - self.last_interpolated) / dt
|
||||
else:
|
||||
velocity = (self.curr_action - self.prev_action) * self.policy_fps
|
||||
|
||||
self.last_interpolated = interpolated.clone()
|
||||
|
||||
for key in self.curr_action:
|
||||
prev = self.prev_action.get(key, self.curr_action[key])
|
||||
curr = self.curr_action[key]
|
||||
interpolated[key] = prev * (1 - t) + curr * t
|
||||
|
||||
if self.last_interpolated is not None and key in self.last_interpolated:
|
||||
velocity[key] = (interpolated[key] - self.last_interpolated[key]) / dt
|
||||
else:
|
||||
velocity[key] = (curr - prev) * self.effective_policy_fps
|
||||
|
||||
self.last_interpolated = interpolated.copy()
|
||||
return interpolated, velocity
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.prev_action = None
|
||||
self.curr_action = None
|
||||
@@ -215,253 +175,160 @@ class HzTracker:
|
||||
self.last_print_time = 0
|
||||
|
||||
|
||||
def get_actions_thread(
|
||||
def interpolated_eval_loop(
|
||||
robot,
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
episode_active: Event,
|
||||
rtc_config: RTCConfig,
|
||||
policy_fps: int,
|
||||
task: str,
|
||||
pretrained_path: str,
|
||||
device: str,
|
||||
):
|
||||
"""Thread function to asynchronously generate action chunks from the policy."""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting action generation thread")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / policy_fps
|
||||
|
||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
policy_device = device
|
||||
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=pretrained_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={"device_processor": {"device": device}},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")
|
||||
|
||||
get_actions_threshold = ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS if rtc_config.enabled else 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
# Filter out non-feature keys (like _timing_breakdown)
|
||||
obs_for_frame = {k: v for k, v in obs_processed.items() if not k.startswith("_")}
|
||||
|
||||
# Check for missing camera keys and wait for them if needed
|
||||
expected_img_keys = [k.removeprefix("observation.images.")
|
||||
for k in hw_features if "images" in k]
|
||||
missing_keys = [k for k in expected_img_keys if k not in obs_for_frame]
|
||||
|
||||
if missing_keys:
|
||||
logger.warning(f"[GET_ACTIONS] Missing camera keys: {missing_keys}, retrying...")
|
||||
# Retry observation to get camera frames
|
||||
for _ in range(5):
|
||||
time.sleep(0.05)
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
obs_for_frame = {k: v for k, v in obs_processed.items() if not k.startswith("_")}
|
||||
missing_keys = [k for k in expected_img_keys if k not in obs_for_frame]
|
||||
if not missing_keys:
|
||||
break
|
||||
|
||||
if missing_keys:
|
||||
logger.error(f"[GET_ACTIONS] Still missing keys after retries: {missing_keys}")
|
||||
logger.error(f"[GET_ACTIONS] Available keys: {list(obs_for_frame.keys())}")
|
||||
continue # Skip this inference cycle
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
hw_features, obs_for_frame, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [task]
|
||||
obs_with_policy_features["robot_type"] = robot.name
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if ACTION_QUEUE_SIZE_TO_GET_NEW_ACTIONS < rtc_config.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
|
||||
"Should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
|
||||
f"delay={new_delay}, queue_size={action_queue.qsize()}"
|
||||
)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[GET_ACTIONS] Action generation thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_thread(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
episode_active: Event,
|
||||
dataset,
|
||||
events,
|
||||
interpolator: ActionInterpolator,
|
||||
robot_hz_tracker: HzTracker,
|
||||
camera_fps: int,
|
||||
effective_policy_fps: int,
|
||||
robot_fps: int,
|
||||
action_keys: list[str],
|
||||
custom_kp: dict | None,
|
||||
custom_kd: dict | None,
|
||||
use_velocity_ff: bool,
|
||||
control_time_s: float,
|
||||
task: str,
|
||||
kp_scale: float | None = None,
|
||||
kd_scale: float | None = None,
|
||||
use_velocity_ff: bool = False,
|
||||
):
|
||||
"""Thread function to execute interpolated actions on the robot at high frequency."""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / robot_fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not episode_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
action = action_queue.get()
|
||||
if action is not None:
|
||||
interpolator.update(action.cpu())
|
||||
|
||||
smooth_action, velocity = interpolator.get_interpolated_action()
|
||||
|
||||
if smooth_action is not None:
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(smooth_action):
|
||||
action_dict[key] = smooth_action[i].item()
|
||||
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
|
||||
vel_ff = None
|
||||
if use_velocity_ff and velocity is not None:
|
||||
vel_ff = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(velocity):
|
||||
motor_name = key.replace(".pos", "")
|
||||
vel_ff[motor_name] = velocity[i].item()
|
||||
|
||||
robot.send_action(action_processed, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff)
|
||||
action_count += 1
|
||||
|
||||
robot_hz_tracker.tick()
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
sleep_time = max(0, action_interval - dt_s - 0.001)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def build_custom_gains(robot, kp_scale: float | None, kd_scale: float | None) -> tuple[dict | None, dict | None]:
|
||||
"""Build custom KP/KD gains for the robot."""
|
||||
if kp_scale is None and kd_scale is None:
|
||||
return None, None
|
||||
|
||||
custom_kp = {}
|
||||
custom_kd = {}
|
||||
for arm in ["right", "left"]:
|
||||
bus = robot.robot.bus_right if arm == "right" else robot.robot.bus_left
|
||||
for i, motor_name in enumerate(bus.motors):
|
||||
full_name = f"{arm}_{motor_name}"
|
||||
default_kp = robot.robot.config.position_kp[i] if isinstance(robot.robot.config.position_kp, list) else robot.robot.config.position_kp
|
||||
default_kd = robot.robot.config.position_kd[i] if isinstance(robot.robot.config.position_kd, list) else robot.robot.config.position_kd
|
||||
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
|
||||
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
|
||||
"""
|
||||
Run evaluation with decoupled camera and robot control:
|
||||
- Camera captures at camera_fps (hardware limit)
|
||||
- Policy inference runs when new camera frame is available
|
||||
- Actions are consumed at effective_policy_fps (sped up by SPEED_MULTIPLIER)
|
||||
- Robot receives interpolated commands at robot_fps (smoothest)
|
||||
"""
|
||||
from lerobot.scripts.lerobot_record import build_dataset_frame, make_robot_action
|
||||
from lerobot.utils.visualization_utils import log_rerun_data
|
||||
|
||||
return custom_kp, custom_kd
|
||||
camera_dt = 1.0 / camera_fps
|
||||
policy_dt = 1.0 / effective_policy_fps
|
||||
robot_dt = 1.0 / robot_fps
|
||||
|
||||
interpolator.reset()
|
||||
robot_hz_tracker.reset()
|
||||
policy.reset()
|
||||
|
||||
# Build custom gains if scaling is enabled
|
||||
custom_kp = None
|
||||
custom_kd = None
|
||||
if kp_scale is not None or kd_scale is not None:
|
||||
custom_kp = {}
|
||||
custom_kd = {}
|
||||
for arm in ["right", "left"]:
|
||||
bus = robot.bus_right if arm == "right" else robot.bus_left
|
||||
for i, motor_name in enumerate(bus.motors):
|
||||
full_name = f"{arm}_{motor_name}"
|
||||
default_kp = robot.config.position_kp[i] if isinstance(robot.config.position_kp, list) else robot.config.position_kp
|
||||
default_kd = robot.config.position_kd[i] if isinstance(robot.config.position_kd, list) else robot.config.position_kd
|
||||
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
|
||||
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
|
||||
print(f"Custom gains: kp_scale={kp_scale}, kd_scale={kd_scale}")
|
||||
|
||||
if use_velocity_ff:
|
||||
print("Velocity feedforward: enabled")
|
||||
|
||||
last_camera_time = -camera_dt
|
||||
last_policy_action_time = -policy_dt
|
||||
cached_observation = None
|
||||
cached_robot_action = None
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
print(f"\nStarting interpolated eval loop:")
|
||||
print(f" Camera: {camera_fps}Hz | Policy actions consumed: {effective_policy_fps}Hz | Robot: {robot_fps}Hz")
|
||||
|
||||
while time.perf_counter() - start_time < control_time_s:
|
||||
if events["exit_early"] or events["stop_recording"]:
|
||||
break
|
||||
|
||||
loop_start = time.perf_counter()
|
||||
elapsed = loop_start - start_time
|
||||
|
||||
# === CAMERA CAPTURE (at camera_fps, decoupled from robot) ===
|
||||
if elapsed - last_camera_time >= camera_dt:
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
|
||||
|
||||
# Run policy inference with fresh observation
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=task,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
|
||||
act_processed = make_robot_action(action_values, dataset.features)
|
||||
cached_robot_action = robot_action_processor((act_processed, obs))
|
||||
cached_observation = (obs_processed, observation_frame, act_processed)
|
||||
|
||||
last_camera_time = elapsed
|
||||
|
||||
# === ACTION UPDATE (at effective_policy_fps, faster than camera if speed > 1) ===
|
||||
if elapsed - last_policy_action_time >= policy_dt and cached_robot_action is not None:
|
||||
interpolator.update(cached_robot_action)
|
||||
last_policy_action_time = elapsed
|
||||
|
||||
# Save to dataset at effective policy rate
|
||||
if dataset is not None and cached_observation is not None:
|
||||
obs_processed, observation_frame, act_processed = cached_observation
|
||||
action_frame = build_dataset_frame(dataset.features, act_processed, prefix="action")
|
||||
frame = {**observation_frame, **action_frame, "task": task}
|
||||
dataset.add_frame(frame)
|
||||
log_rerun_data(observation=obs_processed, action=act_processed)
|
||||
|
||||
# === ROBOT COMMAND (at robot_fps, highest rate for smoothness) ===
|
||||
smooth_action, velocity = interpolator.get_interpolated_action()
|
||||
if smooth_action is not None:
|
||||
vel_ff = velocity if use_velocity_ff else None
|
||||
robot.send_action(smooth_action, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff)
|
||||
|
||||
robot_hz_tracker.tick()
|
||||
|
||||
# Maintain robot control rate
|
||||
sleep_time = robot_dt - (time.perf_counter() - loop_start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# Print final stats
|
||||
robot_hz = robot_hz_tracker.get_avg_hz()
|
||||
if robot_hz:
|
||||
print(f"\nFinal average robot Hz: {robot_hz:.1f}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main evaluation function with RTC and interpolation."""
|
||||
"""Main evaluation function."""
|
||||
print("=" * 60)
|
||||
print("OpenArms Policy Evaluation with RTC + Interpolation")
|
||||
print("OpenArms Policy Evaluation with Interpolation")
|
||||
print("=" * 60)
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"\n--- Timing ---")
|
||||
print(f"Policy FPS: {POLICY_FPS}Hz")
|
||||
print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated)")
|
||||
print(f"\n--- RTC ---")
|
||||
print(f"RTC Enabled: {RTC_ENABLED}")
|
||||
print(f"Execution Horizon: {RTC_EXECUTION_HORIZON}")
|
||||
print(f"Max Guidance Weight: {RTC_MAX_GUIDANCE_WEIGHT}")
|
||||
print(f"\n--- PID ---")
|
||||
print(f"KP scale: {CUSTOM_KP_SCALE}, KD scale: {CUSTOM_KD_SCALE}")
|
||||
print(f"Velocity FF: {USE_VELOCITY_FEEDFORWARD}")
|
||||
print(f"Camera FPS: {CAMERA_FPS} (hardware limit)")
|
||||
print(f"Policy trained at: {POLICY_FPS}Hz")
|
||||
print(f"Speed multiplier: {SPEED_MULTIPLIER}x")
|
||||
print(f"Effective policy FPS: {EFFECTIVE_POLICY_FPS}Hz (actions consumed)")
|
||||
print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated commands)")
|
||||
print(f"\n--- PID Tuning ---")
|
||||
print(f"KP scale: {CUSTOM_KP_SCALE}")
|
||||
print(f"KD scale: {CUSTOM_KD_SCALE}")
|
||||
print(f"Velocity feedforward: {USE_VELOCITY_FEEDFORWARD}")
|
||||
print(f"\n--- Episodes ---")
|
||||
print(f"Episodes: {NUM_EPISODES}, Duration: {EPISODE_TIME_SEC}s")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Duration: {EPISODE_TIME_SEC}s per episode")
|
||||
print(f"Reset time: {RESET_TIME_SEC}s")
|
||||
print(f"Leader for resets: {USE_LEADER_FOR_RESETS}")
|
||||
print("=" * 60)
|
||||
|
||||
shutdown_event = Event()
|
||||
episode_active = Event()
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
@@ -479,9 +346,6 @@ def main():
|
||||
if not follower.is_connected:
|
||||
raise RuntimeError("Follower robot failed to connect!")
|
||||
|
||||
robot = RobotWrapper(follower)
|
||||
logger.info("Follower robot connected")
|
||||
|
||||
leader = None
|
||||
if USE_LEADER_FOR_RESETS:
|
||||
leader_config = OpenArmsLeaderConfig(
|
||||
@@ -502,9 +366,9 @@ def main():
|
||||
leader.bus_right.enable_torque()
|
||||
leader.bus_left.enable_torque()
|
||||
time.sleep(0.1)
|
||||
print("Leader connected with gravity compensation")
|
||||
print(f"Leader connected with gravity compensation")
|
||||
else:
|
||||
print("Leader connected (no gravity compensation)")
|
||||
print(f"Leader connected (no gravity compensation)")
|
||||
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
@@ -537,9 +401,10 @@ def main():
|
||||
leader.disconnect()
|
||||
return
|
||||
|
||||
# Dataset uses effective policy FPS (sped up rate)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=POLICY_FPS,
|
||||
fps=EFFECTIVE_POLICY_FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
@@ -547,104 +412,53 @@ def main():
|
||||
image_writer_threads=12,
|
||||
)
|
||||
|
||||
# Load policy with RTC support
|
||||
logger.info(f"Loading policy from: {HF_MODEL_ID}")
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
policy = policy_class.from_pretrained(HF_MODEL_ID, config=policy_config)
|
||||
|
||||
rtc_config = RTCConfig(
|
||||
enabled=RTC_ENABLED,
|
||||
execution_horizon=RTC_EXECUTION_HORIZON,
|
||||
max_guidance_weight=RTC_MAX_GUIDANCE_WEIGHT,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
)
|
||||
policy.config.rtc_config = rtc_config
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 support RTC"
|
||||
|
||||
policy = policy.to(DEVICE)
|
||||
policy.eval()
|
||||
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
print(f"\nRunning evaluation...")
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_eval_rtc_interp")
|
||||
init_rerun(session_name="openarms_evaluation_interp")
|
||||
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
custom_kp, custom_kd = build_custom_gains(robot, CUSTOM_KP_SCALE, CUSTOM_KD_SCALE)
|
||||
|
||||
if custom_kp:
|
||||
print(f"Custom gains applied")
|
||||
if USE_VELOCITY_FEEDFORWARD:
|
||||
print("Velocity feedforward: enabled")
|
||||
interpolator = ActionInterpolator(effective_policy_fps=EFFECTIVE_POLICY_FPS, robot_fps=ROBOT_FPS)
|
||||
robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
|
||||
|
||||
episode_idx = 0
|
||||
get_actions_t = None
|
||||
actor_t = None
|
||||
|
||||
try:
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---")
|
||||
|
||||
action_queue = ActionQueue(rtc_config)
|
||||
interpolator = ActionInterpolator(policy_fps=POLICY_FPS, robot_fps=ROBOT_FPS)
|
||||
robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
|
||||
|
||||
get_actions_t = Thread(
|
||||
target=get_actions_thread,
|
||||
args=(
|
||||
policy, robot, robot_observation_processor, action_queue,
|
||||
shutdown_event, episode_active, rtc_config, POLICY_FPS,
|
||||
TASK_DESCRIPTION, HF_MODEL_ID, DEVICE,
|
||||
),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
interpolated_eval_loop(
|
||||
robot=follower,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
interpolator=interpolator,
|
||||
robot_hz_tracker=robot_hz_tracker,
|
||||
camera_fps=CAMERA_FPS,
|
||||
effective_policy_fps=EFFECTIVE_POLICY_FPS,
|
||||
robot_fps=ROBOT_FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
task=TASK_DESCRIPTION,
|
||||
kp_scale=CUSTOM_KP_SCALE,
|
||||
kd_scale=CUSTOM_KD_SCALE,
|
||||
use_velocity_ff=USE_VELOCITY_FEEDFORWARD,
|
||||
)
|
||||
get_actions_t.start()
|
||||
|
||||
actor_t = Thread(
|
||||
target=actor_thread,
|
||||
args=(
|
||||
robot, robot_action_processor, action_queue,
|
||||
shutdown_event, episode_active, interpolator, robot_hz_tracker,
|
||||
ROBOT_FPS, action_keys, custom_kp, custom_kd, USE_VELOCITY_FEEDFORWARD,
|
||||
),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_t.start()
|
||||
|
||||
logger.info("Started inference and actor threads")
|
||||
|
||||
episode_active.set()
|
||||
episode_start_time = time.time()
|
||||
|
||||
while (time.time() - episode_start_time) < EPISODE_TIME_SEC:
|
||||
if events["exit_early"] or events["stop_recording"] or shutdown_event.is_set():
|
||||
break
|
||||
|
||||
elapsed = time.time() - episode_start_time
|
||||
if int(elapsed) % 10 == 0 and int(elapsed) > 0:
|
||||
robot_hz = robot_hz_tracker.get_avg_hz()
|
||||
hz_str = f"{robot_hz:.1f}" if robot_hz else "N/A"
|
||||
logger.info(
|
||||
f"Progress: {elapsed:.0f}/{EPISODE_TIME_SEC}s, "
|
||||
f"queue={action_queue.qsize()}, hz={hz_str}"
|
||||
)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
episode_active.clear()
|
||||
|
||||
robot_hz = robot_hz_tracker.get_avg_hz()
|
||||
hz_str = f"{robot_hz:.1f}" if robot_hz else "N/A"
|
||||
logger.info(f"Episode {episode_idx + 1} done. Avg Hz: {hz_str}")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
@@ -752,15 +566,6 @@ def main():
|
||||
print("\n\nInterrupted by user")
|
||||
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
episode_active.clear()
|
||||
|
||||
if get_actions_t is not None and get_actions_t.is_alive():
|
||||
get_actions_t.join(timeout=2.0)
|
||||
|
||||
if actor_t is not None and actor_t.is_alive():
|
||||
actor_t.join(timeout=2.0)
|
||||
|
||||
if leader:
|
||||
leader.bus_right.disable_torque()
|
||||
leader.bus_left.disable_torque()
|
||||
@@ -768,7 +573,6 @@ def main():
|
||||
leader.disconnect()
|
||||
|
||||
follower.disconnect()
|
||||
logger.info("Follower disconnected")
|
||||
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
@@ -0,0 +1,882 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
OpenArms Policy Evaluation with RTC + Interpolation
|
||||
|
||||
Combines Real-Time Chunking (RTC) with smooth action interpolation:
|
||||
- RTC for reactive motion despite high inference latency
|
||||
- Action interpolation for smooth robot movements
|
||||
- Speed multiplier to execute faster than training
|
||||
- Velocity feedforward and PID tuning
|
||||
- Decoupled inference (async) from robot control
|
||||
|
||||
Example usage:
|
||||
python examples/openarms/evaluate_with_rtc_interpolation.py
|
||||
|
||||
# With custom RTC parameters
|
||||
python examples/openarms/evaluate_with_rtc_interpolation.py \
|
||||
--rtc.execution_horizon=12 \
|
||||
--rtc.max_guidance_weight=10.0
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Default Configuration Constants
|
||||
# ============================================================================
|
||||
|
||||
# Hub repo of the trained policy and where evaluation episodes are uploaded.
DEFAULT_HF_MODEL_ID = "lerobot-data-collection/three-folds-pi0"
DEFAULT_HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_rtc_interp"
DEFAULT_TASK_DESCRIPTION = "three-folds-dataset"

DEFAULT_NUM_EPISODES = 1
DEFAULT_CAMERA_FPS = 30  # Camera hardware limit
DEFAULT_POLICY_FPS = 30  # What the policy was trained with
DEFAULT_SPEED_MULTIPLIER = 1.0  # Execute actions faster (1.0 = normal, 1.2 = 20% faster)
DEFAULT_ROBOT_FPS = 50  # Robot command rate (higher = smoother)
DEFAULT_EPISODE_TIME_SEC = 300  # Max wall-clock duration of one evaluation episode
DEFAULT_RESET_TIME_SEC = 60  # Budget for manual environment reset between episodes

# SocketCAN interfaces for the two follower arms.
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"

# PID tuning defaults
DEFAULT_KP_SCALE = 0.7  # Lower = smoother but slower
DEFAULT_KD_SCALE = 1.3  # Higher = less overshoot
DEFAULT_USE_VELOCITY_FF = True  # Velocity feedforward

# Three-camera rig; device paths are host-specific (assumed stable udev ordering — TODO confirm).
DEFAULT_CAMERA_CONFIG = {
    "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_CAMERA_FPS),
    "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_CAMERA_FPS),
    "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_CAMERA_FPS),
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Interpolator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class ActionInterpolator:
|
||||
"""Interpolate between RTC actions for smoother robot control with velocity estimation."""
|
||||
|
||||
def __init__(self, robot_fps: int):
|
||||
self.robot_fps = robot_fps
|
||||
self.prev_action: Tensor | None = None
|
||||
self.curr_action: Tensor | None = None
|
||||
self.prev_time: float = 0
|
||||
self.curr_time: float = 0
|
||||
self.last_interpolated: Tensor | None = None
|
||||
|
||||
def update(self, new_action: Tensor) -> None:
|
||||
self.prev_action = self.curr_action
|
||||
self.prev_time = self.curr_time
|
||||
self.curr_action = new_action
|
||||
self.curr_time = time.perf_counter()
|
||||
|
||||
def get_interpolated_action(self) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Returns (interpolated_position, estimated_velocity)"""
|
||||
if self.curr_action is None:
|
||||
return None, None
|
||||
if self.prev_action is None:
|
||||
self.last_interpolated = self.curr_action.clone()
|
||||
return self.curr_action, torch.zeros_like(self.curr_action)
|
||||
|
||||
# Time-based interpolation
|
||||
current_time = time.perf_counter()
|
||||
dt_actions = self.curr_time - self.prev_time
|
||||
if dt_actions <= 0:
|
||||
dt_actions = 1.0 / 30 # Fallback
|
||||
|
||||
t = (current_time - self.prev_time) / dt_actions
|
||||
t = max(0.0, min(t, 1.5)) # Allow slight extrapolation
|
||||
|
||||
interpolated = self.prev_action + t * (self.curr_action - self.prev_action)
|
||||
|
||||
# Estimate velocity
|
||||
dt_robot = 1.0 / self.robot_fps
|
||||
if self.last_interpolated is not None:
|
||||
velocity = (interpolated - self.last_interpolated) / dt_robot
|
||||
else:
|
||||
velocity = (self.curr_action - self.prev_action) / dt_actions
|
||||
|
||||
self.last_interpolated = interpolated.clone()
|
||||
return interpolated, velocity
|
||||
|
||||
def reset(self):
|
||||
self.prev_action = None
|
||||
self.curr_action = None
|
||||
self.prev_time = 0
|
||||
self.curr_time = 0
|
||||
self.last_interpolated = None
|
||||
|
||||
|
||||
class HzTracker:
|
||||
"""Track and display actual loop frequency."""
|
||||
|
||||
def __init__(self, name: str = "Loop", window_size: int = 100, print_interval: float = 2.0):
|
||||
self.name = name
|
||||
self.timestamps = deque(maxlen=window_size)
|
||||
self.last_print_time = 0
|
||||
self.print_interval = print_interval
|
||||
|
||||
def tick(self) -> float | None:
|
||||
now = time.perf_counter()
|
||||
self.timestamps.append(now)
|
||||
|
||||
if len(self.timestamps) < 2:
|
||||
return None
|
||||
|
||||
hz = (len(self.timestamps) - 1) / (self.timestamps[-1] - self.timestamps[0])
|
||||
|
||||
if now - self.last_print_time >= self.print_interval:
|
||||
print(f"{self.name} Hz: {hz:.1f}")
|
||||
self.last_print_time = now
|
||||
|
||||
return hz
|
||||
|
||||
def get_avg_hz(self) -> float | None:
|
||||
if len(self.timestamps) < 2:
|
||||
return None
|
||||
return (len(self.timestamps) - 1) / (self.timestamps[-1] - self.timestamps[0])
|
||||
|
||||
def reset(self):
|
||||
self.timestamps.clear()
|
||||
self.last_print_time = 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Thread-Safe Robot Wrapper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RobotWrapper:
    """Serialize concurrent robot access from multiple threads, injecting tuned PID gains."""

    def __init__(
        self,
        robot: OpenArmsFollower,
        custom_kp: dict | None = None,
        custom_kd: dict | None = None,
        use_velocity_ff: bool = False,
    ):
        self.robot = robot
        # One lock guards every bus transaction, shared by the inference and actor threads.
        self.lock = Lock()
        self.custom_kp = custom_kp
        self.custom_kd = custom_kd
        self.use_velocity_ff = use_velocity_ff

    def get_observation(self) -> dict[str, Tensor]:
        """Thread-safe read of the robot's current observation."""
        with self.lock:
            return self.robot.get_observation()

    def send_action(self, action: dict, velocity_ff: dict | None = None) -> None:
        """Thread-safe action command; the feedforward term is dropped unless enabled."""
        with self.lock:
            self.robot.send_action(
                action,
                custom_kp=self.custom_kp,
                custom_kd=self.custom_kd,
                velocity_feedforward=velocity_ff if self.use_velocity_ff else None,
            )

    @property
    def observation_features(self) -> dict:
        """Thread-safe view of the wrapped robot's observation feature spec."""
        with self.lock:
            return self.robot.observation_features

    @property
    def action_features(self) -> dict:
        """Thread-safe view of the wrapped robot's action feature spec."""
        with self.lock:
            return self.robot.action_features

    @property
    def name(self) -> str:
        """Name of the wrapped robot (read without locking; assumed immutable)."""
        return self.robot.name
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
class OpenArmsRTCInterpEvalConfig(HubMixin):
    """Configuration for OpenArms evaluation with RTC + Interpolation."""

    # Resolved in __post_init__ from --policy.path (CLI) or model_id (Hub).
    policy: PreTrainedConfig | None = None

    # Real-Time Chunking parameters; overridable via --rtc.* CLI flags.
    rtc: RTCConfig = field(
        default_factory=lambda: RTCConfig(
            enabled=True,
            execution_horizon=10,
            max_guidance_weight=10.0,
            prefix_attention_schedule=RTCAttentionSchedule.EXP,
        )
    )

    model_id: str = DEFAULT_HF_MODEL_ID
    eval_dataset_id: str = DEFAULT_HF_EVAL_DATASET_ID
    task: str = DEFAULT_TASK_DESCRIPTION

    num_episodes: int = DEFAULT_NUM_EPISODES
    camera_fps: float = DEFAULT_CAMERA_FPS
    policy_fps: float = DEFAULT_POLICY_FPS
    speed_multiplier: float = DEFAULT_SPEED_MULTIPLIER
    robot_fps: float = DEFAULT_ROBOT_FPS
    episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
    reset_time_sec: float = DEFAULT_RESET_TIME_SEC

    # PID tuning
    kp_scale: float | None = DEFAULT_KP_SCALE
    kd_scale: float | None = DEFAULT_KD_SCALE
    use_velocity_ff: bool = DEFAULT_USE_VELOCITY_FF

    follower_left_port: str = DEFAULT_FOLLOWER_LEFT_PORT
    follower_right_port: str = DEFAULT_FOLLOWER_RIGHT_PORT

    device: str = "cuda"

    # Should be higher than inference_delay + execution_horizon
    action_queue_size_to_get_new_actions: int = 30

    record_dataset: bool = True
    push_to_hub: bool = True

    # torch.compile knobs (only used when use_torch_compile is set).
    use_torch_compile: bool = False
    torch_compile_backend: str = "inductor"
    torch_compile_mode: str = "default"
    torch_compile_disable_cudagraphs: bool = True

    @property
    def effective_policy_fps(self) -> int:
        """Action-consumption rate: training FPS scaled by the speed multiplier."""
        return int(self.policy_fps * self.speed_multiplier)

    def __post_init__(self):
        # Prefer an explicit --policy.path CLI argument; otherwise fall back to
        # loading the policy config from model_id on the Hub.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
            # Keep model_id in sync so logging reflects what was actually loaded.
            self.model_id = policy_path
        elif self.model_id:
            self.policy = PreTrainedConfig.from_pretrained(self.model_id)
            self.policy.pretrained_path = self.model_id

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """Fields the CLI parser treats as loadable paths."""
        return ["policy"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Action Generation Thread (RTC)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_actions_thread(
    policy,
    robot: RobotWrapper,
    robot_observation_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: OpenArmsRTCInterpEvalConfig,
    episode_active: Event,
):
    """Thread function to asynchronously generate action chunks from the policy using RTC.

    Runs until `shutdown_event` is set; idles while `episode_active` is clear.
    Whenever the queue drains below the refill threshold, it grabs a fresh
    observation, runs inference (passing the measured inference delay and the
    leftover actions of the previous chunk so RTC can blend chunks), and
    merges the new chunk into `action_queue`. Any exception tears down the
    whole process via `shutdown_event` + `sys.exit`.
    """
    try:
        logger.info("[GET_ACTIONS] Starting RTC action generation thread")

        latency_tracker = LatencyTracker()
        time_per_chunk = 1.0 / cfg.effective_policy_fps  # Use effective FPS with speed multiplier

        hw_features = hw_to_dataset_features(robot.observation_features, "observation")
        policy_device = policy.config.device

        logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")

        preprocessor, postprocessor = make_pre_post_processors(
            policy_cfg=cfg.policy,
            pretrained_path=cfg.policy.pretrained_path,
            dataset_stats=None,
            preprocessor_overrides={
                "device_processor": {"device": cfg.device},
            },
        )

        logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully")

        # With RTC disabled there is no chunk blending, so only refill when
        # the queue is fully empty (threshold 0).
        get_actions_threshold = cfg.action_queue_size_to_get_new_actions
        if not cfg.rtc.enabled:
            get_actions_threshold = 0

        while not shutdown_event.is_set():
            if not episode_active.is_set():
                time.sleep(0.01)
                continue

            if action_queue.qsize() <= get_actions_threshold:
                current_time = time.perf_counter()
                # Snapshot queue state before inference so the merge can align
                # the new chunk with what was executed meanwhile.
                action_index_before_inference = action_queue.get_action_index()
                prev_actions = action_queue.get_left_over()

                # Worst-case observed latency, expressed in action steps.
                inference_latency = latency_tracker.max()
                inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0

                obs = robot.get_observation()
                obs_processed = robot_observation_processor(obs)

                obs_with_policy_features = build_dataset_frame(
                    hw_features, obs_processed, prefix="observation"
                )

                # Convert numpy features to batched torch tensors on the policy
                # device; images become float CHW in [0, 1].
                for name in obs_with_policy_features:
                    obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
                    if "image" in name:
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].type(torch.float32) / 255
                        )
                        obs_with_policy_features[name] = (
                            obs_with_policy_features[name].permute(2, 0, 1).contiguous()
                        )
                    obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
                    obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)

                obs_with_policy_features["task"] = [cfg.task]
                obs_with_policy_features["robot_type"] = robot.name

                preprocessed_obs = preprocessor(obs_with_policy_features)

                actions = policy.predict_action_chunk(
                    preprocessed_obs,
                    inference_delay=inference_delay,
                    prev_chunk_left_over=prev_actions,
                )

                # Keep the raw chunk (for the next RTC blend) alongside the
                # postprocessed one (for execution).
                original_actions = actions.squeeze(0).clone()
                postprocessed_actions = postprocessor(actions).squeeze(0)

                new_latency = time.perf_counter() - current_time
                new_delay = math.ceil(new_latency / time_per_chunk)
                latency_tracker.add(new_latency)

                if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
                    logger.warning(
                        "[GET_ACTIONS] action_queue_size_to_get_new_actions too small. "
                        "Should be higher than inference delay + execution horizon."
                    )

                action_queue.merge(
                    original_actions, postprocessed_actions, new_delay, action_index_before_inference
                )

                logger.debug(
                    f"[GET_ACTIONS] Generated chunk, latency={new_latency:.3f}s, "
                    f"delay={new_delay}, queue_size={action_queue.qsize()}"
                )
            else:
                time.sleep(0.01)

        logger.info("[GET_ACTIONS] Action generation thread shutting down")
    except Exception as e:
        logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
        logger.error(traceback.format_exc())
        shutdown_event.set()
        sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Actor Thread with Interpolation
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def actor_thread(
    robot: RobotWrapper,
    robot_action_processor,
    action_queue: ActionQueue,
    shutdown_event: Event,
    cfg: OpenArmsRTCInterpEvalConfig,
    episode_active: Event,
    dataset: LeRobotDataset | None,
    dataset_lock: Lock,
    teleop_action_processor,
    robot_observation_processor,
    interpolator: ActionInterpolator,
    hz_tracker: HzTracker,
):
    """Thread function to execute interpolated actions on the robot at high frequency.

    Two nested rates:
    - at ``cfg.effective_policy_fps``: pop one action from the RTC queue, feed it
      to ``interpolator``, and (optionally) record the frame into ``dataset``;
    - at ``cfg.robot_fps``: send the interpolated position (plus estimated
      velocity feedforward) to the robot.

    Runs until ``shutdown_event`` is set; idles and resets interpolation state
    while ``episode_active`` is clear. Any exception tears down the whole
    process via ``shutdown_event`` + ``sys.exit``.
    """
    try:
        logger.info(f"[ACTOR] Starting actor thread with interpolation at {cfg.robot_fps}Hz")

        action_count = 0
        robot_interval = 1.0 / cfg.robot_fps  # High frequency robot control
        effective_policy_interval = 1.0 / cfg.effective_policy_fps  # Action consume rate
        # Only position features are commanded; order fixes the tensor->dict mapping.
        action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]

        last_action_consume_time = 0
        interpolator.reset()
        hz_tracker.reset()

        while not shutdown_event.is_set():
            if not episode_active.is_set():
                # Between episodes: drop stale interpolation state and stats so
                # the next episode does not blend against old targets.
                time.sleep(0.01)
                interpolator.reset()
                hz_tracker.reset()
                last_action_consume_time = 0
                continue

            start_time = time.perf_counter()

            # Consume new action from RTC queue at effective_policy_fps rate
            current_time = time.perf_counter()
            if current_time - last_action_consume_time >= effective_policy_interval:
                action = action_queue.get()

                if action is not None:
                    action = action.cpu()
                    interpolator.update(action)
                    last_action_consume_time = current_time

                    # Record to dataset at action consume rate (not robot rate),
                    # matching the dataset's declared FPS.
                    if cfg.record_dataset and dataset is not None:
                        with dataset_lock:
                            obs = robot.get_observation()
                            obs_processed = robot_observation_processor(obs)

                            action_dict = {}
                            for i, key in enumerate(action_keys):
                                if i < len(action):
                                    action_dict[key] = action[i].item()

                            action_for_dataset = teleop_action_processor((action_dict, None))

                            frame = {}
                            for key, value in obs_processed.items():
                                frame[f"observation.{key}"] = value
                            for key, value in action_for_dataset.items():
                                frame[f"action.{key}"] = value
                            frame["task"] = cfg.task

                            dataset.add_frame(frame)

            # Get interpolated action and send to robot at robot_fps (highest rate)
            interp_action, velocity = interpolator.get_interpolated_action()

            if interp_action is not None:
                # Convert position tensor to dict keyed like "right_joint_1.pos".
                action_dict = {}
                velocity_dict = {}
                for i, key in enumerate(action_keys):
                    if i < len(interp_action):
                        action_dict[key] = interp_action[i].item()
                        if velocity is not None:
                            # Velocity feedforward is keyed by the bare motor
                            # name, i.e. the action key without its ".pos" suffix.
                            # (Removed a dead `key.replace(...)` assignment that
                            # was immediately overwritten here.)
                            motor_name = key.removesuffix(".pos")
                            velocity_dict[motor_name] = velocity[i].item()

                action_processed = robot_action_processor((action_dict, None))
                robot.send_action(action_processed, velocity_ff=velocity_dict)
                action_count += 1

            hz_tracker.tick()

            # Maintain robot control rate; the 1 ms margin absorbs scheduler jitter.
            dt_s = time.perf_counter() - start_time
            sleep_time = max(0, robot_interval - dt_s - 0.001)
            if sleep_time > 0:
                time.sleep(sleep_time)

        final_hz = hz_tracker.get_avg_hz()
        if final_hz:
            logger.info(f"[ACTOR] Final robot Hz: {final_hz:.1f}")
        logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
    except Exception as e:
        logger.error(f"[ACTOR] Fatal exception: {e}")
        logger.error(traceback.format_exc())
        shutdown_event.set()
        sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def build_custom_gains(robot: OpenArmsFollower, kp_scale: float | None, kd_scale: float | None) -> tuple[dict | None, dict | None]:
    """Derive per-motor PID gain dicts by scaling the defaults from the robot config.

    Returns (None, None) when neither scale is provided, letting callers fall
    back to the robot's built-in gains. A missing (or zero/falsy) individual
    scale behaves like 1.0.
    """
    if kp_scale is None and kd_scale is None:
        return None, None

    kp_factor = kp_scale or 1.0
    kd_factor = kd_scale or 1.0
    kp_gains: dict = {}
    kd_gains: dict = {}

    for arm, bus in (("right", robot.bus_right), ("left", robot.bus_left)):
        for idx, motor in enumerate(bus.motors):
            # Config gains may be a per-motor list or a single shared scalar.
            base_kp = robot.config.position_kp
            base_kd = robot.config.position_kd
            if isinstance(base_kp, list):
                base_kp = base_kp[idx]
            if isinstance(base_kd, list):
                base_kd = base_kd[idx]
            kp_gains[f"{arm}_{motor}"] = base_kp * kp_factor
            kd_gains[f"{arm}_{motor}"] = base_kd * kd_factor

    return kp_gains, kd_gains
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: OpenArmsRTCInterpEvalConfig):
    """Best-effort wrap of the policy's predict_action_chunk in torch.compile.

    pi0/pi05 policies are returned untouched — they handle compilation
    internally via their own ``compile_model`` config flag (set in main()).
    Any compilation failure is logged and the original policy is kept.
    """
    if policy.name in ["pi05", "pi0"]:
        return policy

    try:
        if not hasattr(torch, "compile"):
            logger.warning(
                f"torch.compile not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")

        compile_options = {
            "backend": cfg.torch_compile_backend,
            "mode": cfg.torch_compile_mode,
        }
        if cfg.torch_compile_disable_cudagraphs:
            # Config-driven: pass Triton's cudagraphs toggle through to inductor.
            compile_options["options"] = {"triton.cudagraphs": False}

        policy.predict_action_chunk = torch.compile(
            policy.predict_action_chunk, **compile_options
        )
        logger.info("Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Evaluation Function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@parser.wrap()
def main(cfg: OpenArmsRTCInterpEvalConfig):
    """Main evaluation function with RTC + Interpolation.

    Orchestrates the full run: connects the follower robot, optionally creates
    an evaluation dataset, loads the policy with RTC enabled, starts the
    inference and actor threads, runs the configured number of episodes, and
    cleans everything up in the ``finally`` block.
    """
    init_logging()

    # Human-readable summary of the effective configuration.
    print("=" * 70)
    print("OpenArms Policy Evaluation with RTC + Interpolation")
    print("=" * 70)
    print(f"\nModel: {cfg.model_id}")
    print(f"Evaluation Dataset: {cfg.eval_dataset_id}")
    print(f"Task: {cfg.task}")
    print(f"\n--- Timing ---")
    print(f"Camera FPS: {cfg.camera_fps} (hardware limit)")
    print(f"Policy trained at: {cfg.policy_fps}Hz")
    print(f"Speed multiplier: {cfg.speed_multiplier}x")
    print(f"Effective policy FPS: {cfg.effective_policy_fps}Hz (action consume rate)")
    print(f"Robot FPS: {cfg.robot_fps}Hz (interpolated commands)")
    print(f"\n--- RTC ---")
    print(f"RTC Enabled: {cfg.rtc.enabled}")
    print(f"Execution Horizon: {cfg.rtc.execution_horizon}")
    print(f"Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
    print(f"\n--- PID Tuning ---")
    print(f"KP scale: {cfg.kp_scale}")
    print(f"KD scale: {cfg.kd_scale}")
    print(f"Velocity feedforward: {cfg.use_velocity_ff}")
    print(f"\n--- Episodes ---")
    print(f"Episodes: {cfg.num_episodes}")
    print(f"Duration: {cfg.episode_time_sec}s per episode")
    print(f"Device: {cfg.device}")
    print("=" * 70)

    signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
    shutdown_event = signal_handler.shutdown_event
    episode_active = Event()

    # Initialize Robot
    camera_config = {
        "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=int(cfg.camera_fps)),
        "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=int(cfg.camera_fps)),
        "base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=int(cfg.camera_fps)),
    }

    follower_config = OpenArmsFollowerConfig(
        port_left=cfg.follower_left_port,
        port_right=cfg.follower_right_port,
        can_interface="socketcan",
        id="openarms_follower",
        disable_torque_on_disconnect=True,
        max_relative_target=10.0,
        cameras=camera_config,
    )

    follower = OpenArmsFollower(follower_config)
    follower.connect(calibrate=False)

    if not follower.is_connected:
        raise RuntimeError("Follower robot failed to connect!")

    # Build custom PID gains
    custom_kp, custom_kd = build_custom_gains(follower, cfg.kp_scale, cfg.kd_scale)
    if custom_kp:
        logger.info(f"Custom gains: kp_scale={cfg.kp_scale}, kd_scale={cfg.kd_scale}")
    if cfg.use_velocity_ff:
        logger.info("Velocity feedforward enabled")

    # Thread-safe wrapper shared by the inference and actor threads.
    robot = RobotWrapper(
        follower,
        custom_kp=custom_kp,
        custom_kd=custom_kd,
        use_velocity_ff=cfg.use_velocity_ff,
    )
    logger.info("Follower robot connected")

    # Build Processors and Dataset Features
    teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()

    # Only position features are recorded as actions.
    action_features_hw = {}
    for key, value in follower.action_features.items():
        if key.endswith(".pos"):
            action_features_hw[key] = value

    dataset_features = combine_feature_dicts(
        aggregate_pipeline_dataset_features(
            pipeline=teleop_action_processor,
            initial_features=create_initial_features(action=action_features_hw),
            use_videos=True,
        ),
        aggregate_pipeline_dataset_features(
            pipeline=robot_observation_processor,
            initial_features=create_initial_features(observation=follower.observation_features),
            use_videos=True,
        ),
    )

    # Create or Load Dataset
    dataset = None
    dataset_lock = Lock()

    if cfg.record_dataset:
        dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / cfg.eval_dataset_id
        if dataset_path.exists():
            logger.info(f"Evaluation dataset exists at: {dataset_path}")
            logger.info("New episodes will be appended.")
            choice = input("Continue? (y/n): ").strip().lower()
            if choice != "y":
                logger.info("Aborting evaluation.")
                follower.disconnect()
                return

        # Dataset uses effective policy FPS
        dataset = LeRobotDataset.create(
            repo_id=cfg.eval_dataset_id,
            fps=cfg.effective_policy_fps,
            features=dataset_features,
            robot_type=follower.name,
            use_videos=True,
            image_writer_processes=0,
            image_writer_threads=12,
        )
        logger.info(f"Dataset created: {cfg.eval_dataset_id} at {cfg.effective_policy_fps}Hz")

    # Load Policy
    logger.info(f"Loading policy from: {cfg.model_id}")

    policy_class = get_policy_class(cfg.policy.type)
    config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

    # pi0/pi05 compile internally via their config flag (see _apply_torch_compile).
    if cfg.policy.type in ["pi05", "pi0"]:
        config.compile_model = cfg.use_torch_compile

    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    policy.config.rtc_config = cfg.rtc
    policy.init_rtc_processor()

    assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"

    policy = policy.to(cfg.device)
    policy.eval()

    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

    logger.info(f"Policy loaded: {policy.name}")

    # Create Action Queue, Interpolator, and Hz Tracker
    action_queue = ActionQueue(cfg.rtc)
    interpolator = ActionInterpolator(robot_fps=int(cfg.robot_fps))
    hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)

    # Start Threads
    get_actions_t = Thread(
        target=get_actions_thread,
        args=(
            policy,
            robot,
            robot_observation_processor,
            action_queue,
            shutdown_event,
            cfg,
            episode_active,
        ),
        daemon=True,
        name="GetActions",
    )
    get_actions_t.start()
    logger.info("Started RTC action generation thread")

    actor_t = Thread(
        target=actor_thread,
        args=(
            robot,
            robot_action_processor,
            action_queue,
            shutdown_event,
            cfg,
            episode_active,
            dataset,
            dataset_lock,
            teleop_action_processor,
            robot_observation_processor,
            interpolator,
            hz_tracker,
        ),
        daemon=True,
        name="Actor",
    )
    actor_t.start()
    logger.info(f"Started actor thread with interpolation at {cfg.robot_fps}Hz")

    # Run Evaluation Episodes
    episode_idx = 0

    try:
        while episode_idx < cfg.num_episodes and not shutdown_event.is_set():
            log_say(f"Evaluating episode {episode_idx + 1} of {cfg.num_episodes}")
            logger.info(f"\n{'='*50}")
            logger.info(f"Episode {episode_idx + 1} / {cfg.num_episodes}")
            logger.info(f"{'='*50}")

            # NOTE(review): rebinding this local does NOT replace the queue the
            # already-running threads hold (they captured the original object
            # in their args) — from episode 2 onward the threads keep using the
            # old queue while the progress log below reads this new one. Verify
            # intended behavior / consider clearing the shared queue instead.
            action_queue = ActionQueue(cfg.rtc)
            interpolator.reset()
            hz_tracker.reset()
            episode_active.set()
            episode_start_time = time.time()

            while (time.time() - episode_start_time) < cfg.episode_time_sec:
                if shutdown_event.is_set():
                    break

                elapsed = time.time() - episode_start_time
                if int(elapsed) % 30 == 0 and int(elapsed) > 0:
                    logger.info(
                        f"[MAIN] Episode progress: {elapsed:.0f}/{cfg.episode_time_sec}s, "
                        f"queue_size={action_queue.qsize()}"
                    )

                time.sleep(0.5)

            episode_active.clear()
            logger.info(f"Episode {episode_idx + 1} completed")

            if cfg.record_dataset and dataset is not None:
                with dataset_lock:
                    if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
                        logger.info(
                            f"Saving episode {episode_idx + 1} "
                            f"({dataset.episode_buffer['size']} frames)"
                        )
                        dataset.save_episode()

            episode_idx += 1

            # Manual reset between episodes
            if not shutdown_event.is_set() and episode_idx < cfg.num_episodes:
                log_say("Waiting for manual reset")
                logger.info("Manually reset the environment and press ENTER to continue")
                input("Press ENTER when ready...")

        logger.info(f"Evaluation complete! {episode_idx} episodes recorded")
        log_say("Evaluation complete", blocking=True)

    except KeyboardInterrupt:
        logger.info("\n\nEvaluation interrupted by user")

    finally:
        # Always stop threads, release the robot, and finalize the dataset.
        shutdown_event.set()
        episode_active.clear()

        if get_actions_t.is_alive():
            logger.info("Waiting for action generation thread to finish...")
            get_actions_t.join(timeout=5.0)

        if actor_t.is_alive():
            logger.info("Waiting for actor thread to finish...")
            actor_t.join(timeout=5.0)

        follower.disconnect()
        logger.info("Follower disconnected")

        if cfg.record_dataset and dataset is not None:
            dataset.finalize()
            if cfg.push_to_hub:
                logger.info("Uploading to Hugging Face Hub...")
                dataset.push_to_hub(private=True)

        logger.info("Cleanup completed")


if __name__ == "__main__":
    main()
|
||||
|
||||
@@ -344,15 +344,10 @@ class OpenArmsFollower(Robot):
|
||||
obs_dict[cam_key] = frame
|
||||
except TimeoutError:
|
||||
# If no new frame available, reuse last valid frame from cache
|
||||
# This prevents blocking the entire control loop on slow USB reads
|
||||
if self.camera_frame_cache[cam_key] is not None:
|
||||
obs_dict[cam_key] = self.camera_frame_cache[cam_key]
|
||||
logger.debug(f"Camera {cam_key} timeout, reusing cached frame")
|
||||
else:
|
||||
# First frame not available yet - use blocking read to ensure we get a frame
|
||||
logger.warning(f"Camera {cam_key} no cached frame, doing blocking read")
|
||||
frame = cam.read()
|
||||
self.camera_frame_cache[cam_key] = frame
|
||||
obs_dict[cam_key] = frame
|
||||
|
||||
# Store timing with padded name to align output (e.g. "left_wrist ")
|
||||
timings[f"{cam_key:14s}"] = (time.perf_counter() - t0) * 1000
|
||||
|
||||
Reference in New Issue
Block a user