diff --git a/examples/openarms/evaluate_ee.py b/examples/openarms/evaluate_ee.py
index 39505e6be..b34744769 100644
--- a/examples/openarms/evaluate_ee.py
+++ b/examples/openarms/evaluate_ee.py
@@ -42,13 +42,16 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_featur
 from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
 from lerobot.model.kinematics import RobotKinematics
 from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.policies.utils import make_robot_action
 from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
 from lerobot.utils.constants import ACTION, OBS_STR
+from lerobot.utils.control_utils import predict_action
 from lerobot.utils.relative_actions import (
     convert_state_to_relative,
     convert_from_relative_actions,
     PerTimestepNormalizer,
 )
+from lerobot.utils.utils import get_safe_torch_device
 from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
 from lerobot.processor.converters import (
     robot_action_observation_to_transition,
@@ -278,7 +281,15 @@ def run_ee_inference_loop(
     display_data: bool = True,
 ):
     """Run inference loop with EE conversion and optional UMI-style relative actions."""
+    device = get_safe_torch_device(policy.config.device)
+
+    # Reset policy and processors
+    policy.reset()
+    preprocessor.reset()
+    postprocessor.reset()
+
     dt = 1.0 / fps
+    timestamp = 0
     start_time = time.perf_counter()
     step = 0
 
@@ -289,12 +300,13 @@ def run_ee_inference_loop(
         mode_str += " [relative state]"
 
     print(f"\nRunning EE inference for {duration_s}s...{mode_str}")
 
-    while time.perf_counter() - start_time < duration_s:
-        if events.get("exit_early") or events.get("stop_recording"):
-            break
-
+    while timestamp < duration_s:
         loop_start = time.perf_counter()
 
+        if events.get("exit_early"):
+            events["exit_early"] = False
+            break
+
         # 1. Get robot observation (joint positions)
         robot_obs = robot.get_observation()
@@ -309,44 +321,46 @@ def run_ee_inference_loop(
         # Store current EE position for relative action conversion
         current_ee_pos = torch.tensor([ee_state.get(k, 0.0) for k in sorted(ee_state.keys())])
 
-        # 3. Build policy input with EE state
+        # 3. Build observation frame with EE state for policy input
+        # Get state names from dataset features to build state array
+        state_names = dataset.features.get("observation.state", {}).get("names", [])
+
+        # Build state array from EE values (sorted to match training order)
+        ee_keys = sorted(ee_state.keys())
+        ee_values = [ee_state[k] for k in ee_keys]
+
         # Convert to relative state if enabled (UMI-style)
         if use_relative_state:
-            ee_state_tensor = torch.tensor([ee_state[k] for k in sorted(ee_state.keys())])
+            ee_state_tensor = torch.tensor(ee_values)
             relative_state = convert_state_to_relative(ee_state_tensor.unsqueeze(0))
-            ee_state = {k: float(relative_state[0, i]) for i, k in enumerate(sorted(ee_state.keys()))}
+            ee_values = [float(relative_state[0, i]) for i in range(len(ee_values))]
 
-        policy_obs = {"observation.state": ee_state, "task": task}
+        # Build observation dict for policy (images + state as numpy arrays)
+        observation_frame = {}
 
         # Add images
-        for cam_name in robot.cameras:
-            img = robot_obs.get(f"{cam_name}.image")
-            if img is not None:
-                policy_obs[f"observation.images.{cam_name}"] = img
+        for key, value in robot_obs.items():
+            if ".image" in key:
+                obs_key = f"observation.images.{key.replace('.image', '')}"
+                observation_frame[obs_key] = value
 
-        # 4. Preprocess and run policy
-        batch = preprocessor(policy_obs)
+        # Add state as numpy array
+        observation_frame["observation.state"] = np.array(ee_values, dtype=np.float32)
 
-        # Add batch dimension if needed
-        for key in batch:
-            if isinstance(batch[key], torch.Tensor) and batch[key].dim() == 1:
-                batch[key] = batch[key].unsqueeze(0)
-            elif isinstance(batch[key], torch.Tensor) and batch[key].dim() == 3:
-                batch[key] = batch[key].unsqueeze(0)
+        # 4. Run policy inference using predict_action
+        action_tensor = predict_action(
+            observation=observation_frame,
+            policy=policy,
+            device=device,
+            preprocessor=preprocessor,
+            postprocessor=postprocessor,
+            use_amp=policy.config.use_amp,
+            task=task,
+            robot_type=robot.robot_type,
+        )
 
-        with torch.inference_mode():
-            action_tensor = policy.select_action(batch)
-
-        # 5. Postprocess and convert tensor to dict
-        action_tensor = postprocessor(action_tensor)
-
-        # Flatten to 1D: take first timestep if chunks, squeeze batch dims
-        while action_tensor.dim() > 1:
-            action_tensor = action_tensor[0]
-
-        # Convert tensor to dict using action names from dataset
-        action_names = dataset.features[ACTION]["names"]
-        ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)}
+        # 5. Convert action tensor to dict
+        ee_action = make_robot_action(action_tensor, dataset.features)
 
         # 6. Convert relative action back to absolute if needed
         if use_relative_actions:
@@ -370,11 +384,11 @@ def run_ee_inference_loop(
         # 8. Send joint commands to robot
         robot.send_action(joint_action)
 
-        # 9. Save frame to dataset
+        # 9. Save frame to dataset (save original robot obs + joint action)
         if dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
-            action_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
-            frame = {**observation_frame, **action_frame, "task": task}
+            obs_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
+            act_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
+            frame = {**obs_frame, **act_frame, "task": task}
             dataset.add_frame(frame)
 
         # 10. Visualization
@@ -392,6 +406,8 @@ def run_ee_inference_loop(
         sleep_time = dt - loop_duration
         if sleep_time > 0:
             precise_sleep(sleep_time)
+
+        timestamp = time.perf_counter() - start_time
 
     print(f" Completed {step} steps")
 