make flow similar to evaluate.py

Pepijn
2026-01-08 17:15:14 +01:00
parent a9cf770b99
commit 3ebeb59cdc
+53 -37
@@ -42,13 +42,16 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
 from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
 from lerobot.model.kinematics import RobotKinematics
 from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.policies.utils import make_robot_action
 from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
 from lerobot.utils.constants import ACTION, OBS_STR
+from lerobot.utils.control_utils import predict_action
 from lerobot.utils.relative_actions import (
     convert_state_to_relative,
     convert_from_relative_actions,
     PerTimestepNormalizer,
 )
 from lerobot.utils.utils import get_safe_torch_device
 from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
 from lerobot.processor.converters import (
     robot_action_observation_to_transition,
@@ -278,7 +281,15 @@ def run_ee_inference_loop(
     display_data: bool = True,
 ):
     """Run inference loop with EE conversion and optional UMI-style relative actions."""
+    device = get_safe_torch_device(policy.config.device)
+
+    # Reset policy and processors
     policy.reset()
+    preprocessor.reset()
+    postprocessor.reset()

     dt = 1.0 / fps
+    timestamp = 0
     start_time = time.perf_counter()
     step = 0
@@ -289,12 +300,13 @@ def run_ee_inference_loop(
         mode_str += " [relative state]"

     print(f"\nRunning EE inference for {duration_s}s...{mode_str}")

-    while time.perf_counter() - start_time < duration_s:
-        if events.get("exit_early") or events.get("stop_recording"):
-            break
+    while timestamp < duration_s:
+        loop_start = time.perf_counter()
+
+        if events.get("exit_early"):
+            events["exit_early"] = False
+            break

         # 1. Get robot observation (joint positions)
         robot_obs = robot.get_observation()
@@ -309,44 +321,46 @@ def run_ee_inference_loop(
         # Store current EE position for relative action conversion
         current_ee_pos = torch.tensor([ee_state.get(k, 0.0) for k in sorted(ee_state.keys())])

-        # 3. Build policy input with EE state
+        # 3. Build observation frame with EE state for policy input
+        # Get state names from dataset features to build state array
+        state_names = dataset.features.get("observation.state", {}).get("names", [])
+
+        # Build state array from EE values (sorted to match training order)
+        ee_keys = sorted(ee_state.keys())
+        ee_values = [ee_state[k] for k in ee_keys]

         # Convert to relative state if enabled (UMI-style)
         if use_relative_state:
-            ee_state_tensor = torch.tensor([ee_state[k] for k in sorted(ee_state.keys())])
+            ee_state_tensor = torch.tensor(ee_values)
             relative_state = convert_state_to_relative(ee_state_tensor.unsqueeze(0))
-            ee_state = {k: float(relative_state[0, i]) for i, k in enumerate(sorted(ee_state.keys()))}
+            ee_values = [float(relative_state[0, i]) for i in range(len(ee_values))]

-        policy_obs = {"observation.state": ee_state, "task": task}
+        # Build observation dict for policy (images + state as numpy arrays)
+        observation_frame = {}

         # Add images
-        for cam_name in robot.cameras:
-            img = robot_obs.get(f"{cam_name}.image")
-            if img is not None:
-                policy_obs[f"observation.images.{cam_name}"] = img
+        for key, value in robot_obs.items():
+            if ".image" in key:
+                obs_key = f"observation.images.{key.replace('.image', '')}"
+                observation_frame[obs_key] = value

-        # 4. Preprocess and run policy
-        batch = preprocessor(policy_obs)
+        # Add state as numpy array
+        observation_frame["observation.state"] = np.array(ee_values, dtype=np.float32)

-        # Add batch dimension if needed
-        for key in batch:
-            if isinstance(batch[key], torch.Tensor) and batch[key].dim() == 1:
-                batch[key] = batch[key].unsqueeze(0)
-            elif isinstance(batch[key], torch.Tensor) and batch[key].dim() == 3:
-                batch[key] = batch[key].unsqueeze(0)
+        # 4. Run policy inference using predict_action
+        action_tensor = predict_action(
+            observation=observation_frame,
+            policy=policy,
+            device=device,
+            preprocessor=preprocessor,
+            postprocessor=postprocessor,
+            use_amp=policy.config.use_amp,
+            task=task,
+            robot_type=robot.robot_type,
+        )

-        with torch.inference_mode():
-            action_tensor = policy.select_action(batch)
-        # 5. Postprocess and convert tensor to dict
-        action_tensor = postprocessor(action_tensor)
-        # Flatten to 1D: take first timestep if chunks, squeeze batch dims
-        while action_tensor.dim() > 1:
-            action_tensor = action_tensor[0]
-        # Convert tensor to dict using action names from dataset
-        action_names = dataset.features[ACTION]["names"]
-        ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)}
+        # 5. Convert action tensor to dict
+        ee_action = make_robot_action(action_tensor, dataset.features)

         # 6. Convert relative action back to absolute if needed
         if use_relative_actions:
@@ -370,11 +384,11 @@ def run_ee_inference_loop(
         # 8. Send joint commands to robot
         robot.send_action(joint_action)

-        # 9. Save frame to dataset
+        # 9. Save frame to dataset (save original robot obs + joint action)
         if dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
-            action_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
-            frame = {**observation_frame, **action_frame, "task": task}
+            obs_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
+            act_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
+            frame = {**obs_frame, **act_frame, "task": task}
             dataset.add_frame(frame)

         # 10. Visualization
@@ -392,6 +406,8 @@ def run_ee_inference_loop(
         sleep_time = dt - loop_duration
         if sleep_time > 0:
             precise_sleep(sleep_time)

+        timestamp = time.perf_counter() - start_time

     print(f" Completed {step} steps")
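
For orientation, here is the per-step flow the loop converges on after this change, collapsed into a single sketch. It is assembled from the diff above rather than taken from the repository: the predict_action and make_robot_action calls mirror the new call sites, while the function boundary and argument list here are illustrative assumptions.

import numpy as np

from lerobot.policies.utils import make_robot_action
from lerobot.utils.control_utils import predict_action
from lerobot.utils.utils import get_safe_torch_device


def ee_inference_step(robot_obs, ee_values, policy, preprocessor, postprocessor, dataset, task, robot_type):
    """One evaluate.py-style inference step (hypothetical helper, for illustration).

    robot_obs: raw robot outputs ("<cam>.image" entries plus joint state).
    ee_values: EE state floats, sorted to match training order.
    """
    # Flat observation frame: images keyed as observation.images.<cam>,
    # EE state as a single float32 numpy array.
    observation_frame = {
        f"observation.images.{key.replace('.image', '')}": value
        for key, value in robot_obs.items()
        if ".image" in key
    }
    observation_frame["observation.state"] = np.array(ee_values, dtype=np.float32)

    # predict_action owns preprocessing, batching, inference, and postprocessing,
    # which is what replaces the manual unsqueeze/select_action/postprocessor code.
    action_tensor = predict_action(
        observation=observation_frame,
        policy=policy,
        device=get_safe_torch_device(policy.config.device),
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        use_amp=policy.config.use_amp,
        task=task,
        robot_type=robot_type,
    )

    # Map the flat action tensor back to named robot-action entries.
    return make_robot_action(action_tensor, dataset.features)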