diff --git a/examples/rac/rac_data_collection_openarms_rtc.py b/examples/rac/rac_data_collection_openarms_rtc.py index 9efa6d88f..030db0bc8 100644 --- a/examples/rac/rac_data_collection_openarms_rtc.py +++ b/examples/rac/rac_data_collection_openarms_rtc.py @@ -257,7 +257,7 @@ def get_actions_thread( policy, robot: RobotWrapper, robot_observation_processor, - action_queue: ActionQueue, + action_queue_holder: dict, shutdown_event: Event, cfg: RaCRTCConfig, policy_active: Event, @@ -289,6 +289,7 @@ def get_actions_thread( time.sleep(0.01) continue + action_queue = action_queue_holder["queue"] if action_queue.qsize() <= get_actions_threshold: current_time = time.perf_counter() action_index_before_inference = action_queue.get_action_index() @@ -429,11 +430,11 @@ def main(cfg: RaCRTCConfig): policy.eval() logger.info(f"Policy loaded: {policy.name}") - action_queue = ActionQueue(cfg.rtc) + action_queue_holder = {"queue": ActionQueue(cfg.rtc)} get_actions_t = Thread( target=get_actions_thread, - args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active, fps), + args=(policy, robot, obs_proc, action_queue_holder, shutdown_event, cfg, policy_active, fps), daemon=True, name="GetActions", ) @@ -462,7 +463,8 @@ def main(cfg: RaCRTCConfig): while recorded < cfg.dataset.num_episodes and not events["stop_recording"]: log_say(f"RaC episode {recorded + 1}", play_sounds=cfg.play_sounds) - action_queue = ActionQueue(cfg.rtc) + action_queue_holder["queue"] = ActionQueue(cfg.rtc) + action_queue = action_queue_holder["queue"] events["policy_paused"] = False events["correction_active"] = False events["start_correction"] = False