interpolation

This commit is contained in:
Pepijn
2026-01-09 11:49:25 +01:00
parent c85f1692d6
commit 498e215444
+130 -26
View File
@@ -33,6 +33,16 @@ Example usage:
python examples/openarms/evaluate_with_rtc.py \
--rtc.execution_horizon=12 \
--rtc.max_guidance_weight=10.0
# With custom P/D gain scaling
python examples/openarms/evaluate_with_rtc.py \
--kp_scale=0.8 \
--kd_scale=1.2
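# With action interpolation (policy at 30Hz, robot at 50Hz).
# Note: the number of interpolation steps is int(control_hz / fps), so the ratio
# is truncated (50 / 30 -> 1 step, i.e. no intermediate actions are generated);
# pick control_hz as an integer multiple of the policy fps (e.g. 60) to get
# interpolated steps.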
python examples/openarms/evaluate_with_rtc.py \
--action_interpolation_enabled=true \
--control_hz=50
"""
import logging
@@ -82,6 +92,10 @@ DEFAULT_FPS = 30
# Presumably the default episode / reset durations in seconds (per the _SEC
# suffix); their use is outside this view — confirm.
DEFAULT_EPISODE_TIME_SEC = 300
DEFAULT_RESET_TIME_SEC = 60
# Default multipliers applied to the robot config's position_kp / position_kd
# gains when building the custom gain dicts; 1.0 leaves the gains unscaled.
DEFAULT_KP_SCALE = 1.0
DEFAULT_KD_SCALE = 1.0
# Default actor-loop rate in Hz; only consulted when action interpolation is
# enabled (otherwise the loop runs at the policy fps).
DEFAULT_CONTROL_HZ = 50
# Default ports for the left / right follower arms.
# NOTE(review): the values look like SocketCAN interface names — confirm.
DEFAULT_FOLLOWER_LEFT_PORT = "can0"
DEFAULT_FOLLOWER_RIGHT_PORT = "can1"
@@ -104,13 +118,17 @@ class RobotWrapper:
self.robot = robot
self.lock = Lock()
@property
def config(self):
    """Expose the wrapped robot's ``config`` attribute.

    Read-only passthrough; the wrapper lock is not taken for this access.
    """
    wrapped_robot = self.robot
    return wrapped_robot.config
def get_observation(self) -> dict[str, Tensor]:
    """Read one observation from the wrapped robot under the wrapper lock.

    Returns:
        Whatever ``self.robot.get_observation()`` returns, unmodified.
    """
    with self.lock:
        # The robot call happens while the lock is held; the lock is released
        # before the value is handed back to the caller.
        observation = self.robot.get_observation()
    return observation
def send_action(self, action: dict, custom_kp: dict | None = None, custom_kd: dict | None = None) -> None:
    """Send one action to the wrapped robot while holding the wrapper lock.

    The action is sent exactly once per call.

    Args:
        action: Action dict, passed through unchanged to
            ``self.robot.send_action``.
        custom_kp: Optional gain dict forwarded as the ``custom_kp`` keyword
            (``None`` is forwarded as-is when omitted).
        custom_kd: Optional gain dict forwarded as the ``custom_kd`` keyword
            (``None`` is forwarded as-is when omitted).
    """
    with self.lock:
        self.robot.send_action(action, custom_kp=custom_kp, custom_kd=custom_kd)
@property
def observation_features(self) -> dict:
@@ -167,6 +185,12 @@ class OpenArmsRTCEvalConfig(HubMixin):
record_dataset: bool = True
push_to_hub: bool = True
kp_scale: float = DEFAULT_KP_SCALE
kd_scale: float = DEFAULT_KD_SCALE
action_interpolation_enabled: bool = False
control_hz: float = DEFAULT_CONTROL_HZ
use_torch_compile: bool = False
torch_compile_backend: str = "inductor"
torch_compile_mode: str = "default"
@@ -309,6 +333,26 @@ def get_actions_thread(
# ============================================================================
def _build_scaled_gains(
    robot: RobotWrapper,
    kp_scale: float,
    kd_scale: float,
    *,
    motor_names: tuple[str, ...] = (
        "joint_1",
        "joint_2",
        "joint_3",
        "joint_4",
        "joint_5",
        "joint_6",
        "joint_7",
        "gripper",
    ),
    arm_prefixes: tuple[str, ...] = ("right", "left"),
) -> tuple[dict, dict]:
    """Build scaled kp/kd dicts for all motors of every arm.

    The i-th entry of ``robot.config.position_kp`` / ``position_kd`` is
    multiplied by the matching scale and stored under
    ``"<prefix>_<motor_name>"`` for each arm prefix, so all arms share the
    same scaled gain per motor.
    NOTE(review): assumes the config gain sequences are ordered like
    ``motor_names`` — confirm against the robot config.

    Args:
        robot: Object exposing ``config.position_kp`` and
            ``config.position_kd`` as indexable sequences.
        kp_scale: Multiplier applied to every position kp gain.
        kd_scale: Multiplier applied to every position kd gain.
        motor_names: Motor names, in the same order as the gain sequences.
        arm_prefixes: Key prefixes; one entry per motor is emitted for each.

    Returns:
        ``(custom_kp, custom_kd)`` dicts keyed by ``"<prefix>_<motor_name>"``.

    Raises:
        IndexError: If either gain sequence has fewer entries than
            ``motor_names``.
    """
    base_kp = robot.config.position_kp
    base_kd = robot.config.position_kd

    # Fail before building anything, with a message that names the mismatch,
    # rather than with a bare "list index out of range" halfway through.
    n_motors = len(motor_names)
    if len(base_kp) < n_motors or len(base_kd) < n_motors:
        raise IndexError(
            f"position_kp/position_kd must have at least {n_motors} entries "
            f"(got {len(base_kp)} and {len(base_kd)})"
        )

    custom_kp: dict = {}
    custom_kd: dict = {}
    for idx, motor_name in enumerate(motor_names):
        scaled_kp = base_kp[idx] * kp_scale
        scaled_kd = base_kd[idx] * kd_scale
        for prefix in arm_prefixes:
            key = f"{prefix}_{motor_name}"
            custom_kp[key] = scaled_kp
            custom_kd[key] = scaled_kd
    return custom_kp, custom_kd
def _interpolate_actions(prev_action: Tensor, next_action: Tensor, alpha: float) -> Tensor:
    """Blend linearly from ``prev_action`` toward ``next_action``.

    ``alpha`` is the blend fraction: 0 yields ``prev_action`` and 1 yields
    ``next_action``. The value is not clamped, so fractions outside [0, 1]
    extrapolate.
    """
    step_delta = next_action - prev_action
    return prev_action + alpha * step_delta
def actor_thread(
robot: RobotWrapper,
robot_action_processor,
@@ -324,49 +368,104 @@ def actor_thread(
"""Thread function to execute actions on the robot."""
try:
logger.info("[ACTOR] Starting actor thread")
logger.info(f"[ACTOR] kp_scale={cfg.kp_scale}, kd_scale={cfg.kd_scale}")
logger.info(f"[ACTOR] interpolation={cfg.action_interpolation_enabled}, control_hz={cfg.control_hz}")
custom_kp, custom_kd = _build_scaled_gains(robot, cfg.kp_scale, cfg.kd_scale)
action_count = 0
action_interval = 1.0 / cfg.fps
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
if cfg.action_interpolation_enabled:
control_interval = 1.0 / cfg.control_hz
interp_steps = int(cfg.control_hz / cfg.fps)
else:
control_interval = 1.0 / cfg.fps
interp_steps = 1
prev_action: Tensor | None = None
current_action: Tensor | None = None
interp_step = 0
last_dataset_frame_time = 0.0
while not shutdown_event.is_set():
if not episode_active.is_set():
prev_action = None
current_action = None
interp_step = 0
time.sleep(0.01)
continue
start_time = time.perf_counter()
action = action_queue.get()
if action is not None:
action = action.cpu()
if cfg.action_interpolation_enabled:
if interp_step == 0 or current_action is None:
new_action = action_queue.get()
if new_action is not None:
prev_action = current_action if current_action is not None else new_action.cpu()
current_action = new_action.cpu()
interp_step = 0
action_dict = {}
for i, key in enumerate(action_keys):
if i < len(action):
action_dict[key] = action[i].item()
if current_action is not None:
if prev_action is not None and interp_steps > 1:
alpha = (interp_step + 1) / interp_steps
action_to_send = _interpolate_actions(prev_action, current_action, alpha)
else:
action_to_send = current_action
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed)
action_dict = {}
for i, key in enumerate(action_keys):
if i < len(action_to_send):
action_dict[key] = action_to_send[i].item()
if cfg.record_dataset and dataset is not None:
with dataset_lock:
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
action_for_dataset = teleop_action_processor((action_dict, None))
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed, custom_kp=custom_kp, custom_kd=custom_kd)
action_count += 1
frame = {}
for key, value in obs_processed.items():
frame[f"observation.{key}"] = value
for key, value in action_for_dataset.items():
frame[f"action.{key}"] = value
frame["task"] = cfg.task
interp_step = (interp_step + 1) % interp_steps
dataset.add_frame(frame)
if cfg.record_dataset and dataset is not None:
if time.perf_counter() - last_dataset_frame_time >= (1.0 / cfg.fps):
last_dataset_frame_time = time.perf_counter()
with dataset_lock:
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
action_for_dataset = teleop_action_processor((action_dict, None))
frame = {}
for key, value in obs_processed.items():
frame[f"observation.{key}"] = value
for key, value in action_for_dataset.items():
frame[f"action.{key}"] = value
frame["task"] = cfg.task
dataset.add_frame(frame)
else:
action = action_queue.get()
if action is not None:
action = action.cpu()
action_dict = {}
for i, key in enumerate(action_keys):
if i < len(action):
action_dict[key] = action[i].item()
action_count += 1
action_processed = robot_action_processor((action_dict, None))
robot.send_action(action_processed, custom_kp=custom_kp, custom_kd=custom_kd)
action_count += 1
if cfg.record_dataset and dataset is not None:
with dataset_lock:
obs = robot.get_observation()
obs_processed = robot_observation_processor(obs)
action_for_dataset = teleop_action_processor((action_dict, None))
frame = {}
for key, value in obs_processed.items():
frame[f"observation.{key}"] = value
for key, value in action_for_dataset.items():
frame[f"action.{key}"] = value
frame["task"] = cfg.task
dataset.add_frame(frame)
dt_s = time.perf_counter() - start_time
sleep_time = max(0, action_interval - dt_s - 0.001)
sleep_time = max(0, control_interval - dt_s - 0.001)
if sleep_time > 0:
time.sleep(sleep_time)
@@ -434,6 +533,11 @@ def main(cfg: OpenArmsRTCEvalConfig):
print(f"RTC Enabled: {cfg.rtc.enabled}")
print(f"RTC Execution Horizon: {cfg.rtc.execution_horizon}")
print(f"RTC Max Guidance Weight: {cfg.rtc.max_guidance_weight}")
print(f"Kp Scale: {cfg.kp_scale}")
print(f"Kd Scale: {cfg.kd_scale}")
print(f"Action Interpolation: {cfg.action_interpolation_enabled}")
if cfg.action_interpolation_enabled:
print(f"Control Hz: {cfg.control_hz}")
print(f"Device: {cfg.device}")
print("=" * 60)