mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
add interpolation option
This commit is contained in:
@@ -79,6 +79,7 @@ DEFAULT_TASK_DESCRIPTION = "three-folds-dataset"
|
|||||||
|
|
||||||
DEFAULT_NUM_EPISODES = 1
|
DEFAULT_NUM_EPISODES = 1
|
||||||
DEFAULT_FPS = 30
|
DEFAULT_FPS = 30
|
||||||
|
DEFAULT_ROBOT_HZ = 50
|
||||||
DEFAULT_EPISODE_TIME_SEC = 300
|
DEFAULT_EPISODE_TIME_SEC = 300
|
||||||
DEFAULT_RESET_TIME_SEC = 60
|
DEFAULT_RESET_TIME_SEC = 60
|
||||||
|
|
||||||
@@ -167,6 +168,9 @@ class OpenArmsRTCEvalConfig(HubMixin):
|
|||||||
record_dataset: bool = True
|
record_dataset: bool = True
|
||||||
push_to_hub: bool = True
|
push_to_hub: bool = True
|
||||||
|
|
||||||
|
interpolation: bool = False
|
||||||
|
robot_hz: float = DEFAULT_ROBOT_HZ
|
||||||
|
|
||||||
use_torch_compile: bool = False
|
use_torch_compile: bool = False
|
||||||
torch_compile_backend: str = "inductor"
|
torch_compile_backend: str = "inductor"
|
||||||
torch_compile_mode: str = "default"
|
torch_compile_mode: str = "default"
|
||||||
@@ -323,54 +327,79 @@ def actor_thread(
|
|||||||
):
|
):
|
||||||
"""Thread function to execute actions on the robot."""
|
"""Thread function to execute actions on the robot."""
|
||||||
try:
|
try:
|
||||||
logger.info("[ACTOR] Starting actor thread")
|
|
||||||
|
|
||||||
action_count = 0
|
|
||||||
action_interval = 1.0 / cfg.fps
|
|
||||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||||
|
action_count = 0
|
||||||
|
|
||||||
|
if cfg.interpolation:
|
||||||
|
robot_interval = 1.0 / cfg.robot_hz
|
||||||
|
interp_steps = int(cfg.robot_hz / cfg.fps)
|
||||||
|
logger.info(f"[ACTOR] Interpolation ON: policy={cfg.fps}Hz -> robot={cfg.robot_hz}Hz ({interp_steps} steps)")
|
||||||
|
else:
|
||||||
|
robot_interval = 1.0 / cfg.fps
|
||||||
|
interp_steps = 1
|
||||||
|
logger.info(f"[ACTOR] Interpolation OFF: policy={cfg.fps}Hz, robot={cfg.fps}Hz")
|
||||||
|
|
||||||
|
prev_action: Tensor | None = None
|
||||||
|
current_action: Tensor | None = None
|
||||||
|
interp_step = 0
|
||||||
|
last_dataset_time = 0.0
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
while not shutdown_event.is_set():
|
||||||
if not episode_active.is_set():
|
if not episode_active.is_set():
|
||||||
|
prev_action = None
|
||||||
|
current_action = None
|
||||||
|
interp_step = 0
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
action = action_queue.get()
|
|
||||||
|
|
||||||
if action is not None:
|
if interp_step == 0 or current_action is None:
|
||||||
action = action.cpu()
|
new_action = action_queue.get()
|
||||||
|
if new_action is not None:
|
||||||
|
prev_action = current_action if current_action is not None else new_action.cpu()
|
||||||
|
current_action = new_action.cpu()
|
||||||
|
interp_step = 0
|
||||||
|
|
||||||
|
if current_action is not None:
|
||||||
|
if cfg.interpolation and prev_action is not None and interp_steps > 1:
|
||||||
|
alpha = (interp_step + 1) / interp_steps
|
||||||
|
action_to_send = prev_action + alpha * (current_action - prev_action)
|
||||||
|
else:
|
||||||
|
action_to_send = current_action
|
||||||
|
|
||||||
action_dict = {}
|
action_dict = {}
|
||||||
for i, key in enumerate(action_keys):
|
for i, key in enumerate(action_keys):
|
||||||
if i < len(action):
|
if i < len(action_to_send):
|
||||||
action_dict[key] = action[i].item()
|
action_dict[key] = action_to_send[i].item()
|
||||||
|
|
||||||
action_processed = robot_action_processor((action_dict, None))
|
action_processed = robot_action_processor((action_dict, None))
|
||||||
robot.send_action(action_processed)
|
robot.send_action(action_processed)
|
||||||
|
action_count += 1
|
||||||
|
interp_step = (interp_step + 1) % interp_steps
|
||||||
|
|
||||||
if cfg.record_dataset and dataset is not None:
|
if cfg.record_dataset and dataset is not None:
|
||||||
with dataset_lock:
|
now = time.perf_counter()
|
||||||
obs = robot.get_observation()
|
if now - last_dataset_time >= (1.0 / cfg.fps):
|
||||||
obs_processed = robot_observation_processor(obs)
|
last_dataset_time = now
|
||||||
action_for_dataset = teleop_action_processor((action_dict, None))
|
with dataset_lock:
|
||||||
|
obs = robot.get_observation()
|
||||||
frame = {}
|
obs_processed = robot_observation_processor(obs)
|
||||||
for key, value in obs_processed.items():
|
action_for_dataset = teleop_action_processor((action_dict, None))
|
||||||
frame[f"observation.{key}"] = value
|
frame = {}
|
||||||
for key, value in action_for_dataset.items():
|
for key, value in obs_processed.items():
|
||||||
frame[f"action.{key}"] = value
|
frame[f"observation.{key}"] = value
|
||||||
frame["task"] = cfg.task
|
for key, value in action_for_dataset.items():
|
||||||
|
frame[f"action.{key}"] = value
|
||||||
dataset.add_frame(frame)
|
frame["task"] = cfg.task
|
||||||
|
dataset.add_frame(frame)
|
||||||
action_count += 1
|
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_time
|
dt_s = time.perf_counter() - start_time
|
||||||
sleep_time = max(0, action_interval - dt_s - 0.001)
|
sleep_time = max(0, robot_interval - dt_s - 0.001)
|
||||||
if sleep_time > 0:
|
if sleep_time > 0:
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
logger.info(f"[ACTOR] Shutting down. Total actions: {action_count}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[ACTOR] Fatal exception: {e}")
|
logger.error(f"[ACTOR] Fatal exception: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
@@ -434,6 +463,9 @@ def main(cfg: OpenArmsRTCEvalConfig):
|
|||||||
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
print(f"RTC Enabled: {cfg.rtc.enabled}")
|
||||||
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
|
||||||
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
|
||||||
|
print(f"Policy Hz: {cfg.fps}")
|
||||||
|
print(f"Robot Hz: {cfg.robot_hz if cfg.interpolation else cfg.fps}")
|
||||||
|
print(f"Interpolation: {cfg.interpolation}")
|
||||||
print(f"Device: {cfg.device}")
|
print(f"Device: {cfg.device}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user