fix at 2x actual freq

This commit is contained in:
Pepijn
2026-01-09 13:03:29 +01:00
parent 7ac05c838d
commit 7d6f113072
+42 -24
View File
@@ -79,7 +79,6 @@ DEFAULT_TASK_DESCRIPTION = "three-folds-dataset"
DEFAULT_NUM_EPISODES = 1
DEFAULT_FPS = 30
DEFAULT_ROBOT_HZ = 50
DEFAULT_EPISODE_TIME_SEC = 300
DEFAULT_RESET_TIME_SEC = 60
@@ -169,7 +168,6 @@ class OpenArmsRTCEvalConfig(HubMixin):
push_to_hub: bool = True
interpolation: bool = False
robot_hz: float = DEFAULT_ROBOT_HZ
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
@@ -328,45 +326,56 @@ def actor_thread(
"""Thread function to execute actions on the robot."""
try:
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
action_count = 0
if cfg.interpolation:
robot_interval = 1.0 / cfg.robot_hz
interp_steps = int(cfg.robot_hz / cfg.fps)
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.robot_hz}Hz ({interp_steps} steps)")
interp_factor = 2
robot_interval = 1.0 / (cfg.fps * interp_factor)
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.fps * interp_factor}Hz (2x)")
else:
interp_factor = 1
robot_interval = 1.0 / cfg.fps
interp_steps = 1
logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz")
prev_action: Tensor | None = None
current_action: Tensor | None = None
interp_step = 0
interpolated_actions: list[Tensor] = []
interp_idx = 0
robot_send_count = 0
policy_consume_count = 0
last_hz_print = time.perf_counter()
last_dataset_time = 0.0
while not shutdown_event.is_set():
if not episode_active.is_set():
prev_action = None
current_action = None
interp_step = 0
interpolated_actions = []
interp_idx = 0
robot_send_count = 0
policy_consume_count = 0
last_hz_print = time.perf_counter()
time.sleep(0.01)
continue
start_time = time.perf_counter()
if interp_step == 0 or current_action is None:
if interp_idx >= len(interpolated_actions):
new_action = action_queue.get()
if new_action is not None:
prev_action = current_action if current_action is not None else new_action.cpu()
current_action = new_action.cpu()
interp_step = 0
policy_consume_count += 1
if current_action is not None:
if cfg.interpolation and prev_action is not None and interp_steps > 1:
alpha = (interp_step + 1) / interp_steps
action_to_send = prev_action + alpha * (current_action - prev_action)
else:
action_to_send = current_action
if cfg.interpolation and prev_action is not None:
mid = prev_action + 0.5 * (current_action - prev_action)
interpolated_actions = [mid, current_action]
else:
interpolated_actions = [current_action]
prev_action = current_action
interp_idx = 0
if interp_idx < len(interpolated_actions):
action_to_send = interpolated_actions[interp_idx]
interp_idx += 1
action_dict = {}
for i, key in enumerate(action_keys):
@@ -375,8 +384,7 @@ def actor_thread(
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_count += 1
interp_step = (interp_step + 1) % interp_steps
robot_send_count += 1
if cfg.record_dataset and dataset is not None:
now = time.perf_counter()
@@ -394,12 +402,22 @@ def actor_thread(
frame["task"] = cfg.task
dataset.add_frame(frame)
now = time.perf_counter()
if now - last_hz_print >= 5.0:
elapsed = now - last_hz_print
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}")
robot_send_count = 0
policy_consume_count = 0
last_hz_print = now
dt_s = time.perf_counter() - start_time
sleep_time = max(0, robot_interval - dt_s - 0.001)
if sleep_time > 0:
time.sleep(sleep_time)
logger.info(f"[ACTOR] Shutting down. Total actions: {action_count}")
logger.info("[ACTOR] Shutting down")
except Exception as e:
logger.error(f"[ACTOR] Fatal exception: {e}")
logger.error(traceback.format_exc())
@@ -464,7 +482,7 @@ def main(cfg: OpenArmsRTCEvalConfig):
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
print(f"Policy Hz: {cfg.fps}")
print(f"Robot Hz: {cfg.robot_hz if cfg.interpolation else cfg.fps}")
print(f"Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
print(f"Interpolation: {cfg.interpolation}")
print(f"Device: {cfg.device}")
print("=" * 60)