make flow similar to evaluate.py
@@ -42,13 +42,16 @@ from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features
 from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
 from lerobot.model.kinematics import RobotKinematics
 from lerobot.policies.factory import make_policy, make_pre_post_processors
+from lerobot.policies.utils import make_robot_action
 from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
 from lerobot.utils.constants import ACTION, OBS_STR
+from lerobot.utils.control_utils import predict_action
 from lerobot.utils.relative_actions import (
     convert_state_to_relative,
     convert_from_relative_actions,
     PerTimestepNormalizer,
 )
+from lerobot.utils.utils import get_safe_torch_device
 from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
 from lerobot.processor.converters import (
     robot_action_observation_to_transition,
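
Review note: the three added imports are exactly what the evaluate.py-style flow below needs (make_robot_action, predict_action, get_safe_torch_device); the relative_actions imports were already present and implement the UMI-style relative conversion used in the loop. As a rough, self-contained illustration of that relative-frame idea (plain torch, translation only, not the actual lerobot.utils.relative_actions code):

import torch

def to_relative(traj: torch.Tensor, current: torch.Tensor) -> torch.Tensor:
    # traj: (T, D) absolute EE poses; current: (D,) pose at inference time.
    # Each timestep becomes an offset from "now", so the policy sees data
    # that is invariant to where the arm happens to start.
    return traj - current

def from_relative(rel_traj: torch.Tensor, current: torch.Tensor) -> torch.Tensor:
    # Invert the conversion before sending actions to the robot.
    return rel_traj + current

current = torch.tensor([0.30, 0.05, 0.12])            # EE x, y, z
actions = torch.tensor([[0.31, 0.05, 0.12],
                        [0.33, 0.06, 0.11]])          # absolute EE targets
rel = to_relative(actions, current)                   # small deltas around zero
assert torch.allclose(from_relative(rel, current), actions)

Real EE poses also carry rotation, which needs a proper relative transform rather than a subtraction; the sketch only shows the translational case.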
@@ -278,7 +281,15 @@ def run_ee_inference_loop(
     display_data: bool = True,
 ):
     """Run inference loop with EE conversion and optional UMI-style relative actions."""
+    device = get_safe_torch_device(policy.config.device)
+
+    # Reset policy and processors
+    policy.reset()
+    preprocessor.reset()
+    postprocessor.reset()
+
     dt = 1.0 / fps
+    timestamp = 0
     start_time = time.perf_counter()
     step = 0

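Review note: the device is now resolved once before the loop, and the policy and both processors are reset so that no queued action chunks or processor state leak in from a previous run. A sketch of the fallback behavior the helper's name suggests (assumed, not the actual lerobot.utils.utils.get_safe_torch_device code):

import torch

def safe_device(try_device: str) -> torch.device:
    # Assumed behavior: fall back to CPU when the requested accelerator
    # is unavailable instead of raising at inference time.
    if try_device.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    if try_device.startswith("mps") and not torch.backends.mps.is_available():
        return torch.device("cpu")
    return torch.device(try_device)

print(safe_device("cuda"))  # prints "cpu" on a machine without CUDA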
@@ -289,12 +300,13 @@ def run_ee_inference_loop(
         mode_str += " [relative state]"
     print(f"\nRunning EE inference for {duration_s}s...{mode_str}")

-    while time.perf_counter() - start_time < duration_s:
-        if events.get("exit_early") or events.get("stop_recording"):
-            break
-
+    while timestamp < duration_s:
         loop_start = time.perf_counter()

+        if events.get("exit_early"):
+            events["exit_early"] = False
+            break
+
         # 1. Get robot observation (joint positions)
         robot_obs = robot.get_observation()

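Review note: besides switching the loop condition to the timestamp counter, the rewritten loop consumes the exit_early flag before breaking, so a later phase that reads the same events dict does not exit immediately on a stale value; the old stop_recording check is dropped. A hypothetical minimal version of that handshake (in lerobot the dict is populated by a keyboard-listener thread):

events = {"exit_early": False}

def loop(max_steps: int = 100) -> int:
    steps = 0
    for _ in range(max_steps):
        if events.get("exit_early"):
            events["exit_early"] = False  # consume the flag so the next
            break                         # phase doesn't exit immediately
        steps += 1
    return steps

events["exit_early"] = True   # simulate a keypress before the loop runs
assert loop() == 0 and events["exit_early"] is False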
@@ -309,44 +321,46 @@ def run_ee_inference_loop(
         # Store current EE position for relative action conversion
         current_ee_pos = torch.tensor([ee_state.get(k, 0.0) for k in sorted(ee_state.keys())])

-        # 3. Build policy input with EE state
+        # 3. Build observation frame with EE state for policy input
+        # Get state names from dataset features to build state array
+        state_names = dataset.features.get("observation.state", {}).get("names", [])
+
+        # Build state array from EE values (sorted to match training order)
+        ee_keys = sorted(ee_state.keys())
+        ee_values = [ee_state[k] for k in ee_keys]
+
         # Convert to relative state if enabled (UMI-style)
         if use_relative_state:
-            ee_state_tensor = torch.tensor([ee_state[k] for k in sorted(ee_state.keys())])
+            ee_state_tensor = torch.tensor(ee_values)
             relative_state = convert_state_to_relative(ee_state_tensor.unsqueeze(0))
-            ee_state = {k: float(relative_state[0, i]) for i, k in enumerate(sorted(ee_state.keys()))}
+            ee_values = [float(relative_state[0, i]) for i in range(len(ee_values))]

-        policy_obs = {"observation.state": ee_state, "task": task}
+        # Build observation dict for policy (images + state as numpy arrays)
+        observation_frame = {}

         # Add images
-        for cam_name in robot.cameras:
-            img = robot_obs.get(f"{cam_name}.image")
-            if img is not None:
-                policy_obs[f"observation.images.{cam_name}"] = img
+        for key, value in robot_obs.items():
+            if ".image" in key:
+                obs_key = f"observation.images.{key.replace('.image', '')}"
+                observation_frame[obs_key] = value

-        # 4. Preprocess and run policy
-        batch = preprocessor(policy_obs)
+        # Add state as numpy array
+        observation_frame["observation.state"] = np.array(ee_values, dtype=np.float32)

-        # Add batch dimension if needed
-        for key in batch:
-            if isinstance(batch[key], torch.Tensor) and batch[key].dim() == 1:
-                batch[key] = batch[key].unsqueeze(0)
-            elif isinstance(batch[key], torch.Tensor) and batch[key].dim() == 3:
-                batch[key] = batch[key].unsqueeze(0)
+        # 4. Run policy inference using predict_action
+        action_tensor = predict_action(
+            observation=observation_frame,
+            policy=policy,
+            device=device,
+            preprocessor=preprocessor,
+            postprocessor=postprocessor,
+            use_amp=policy.config.use_amp,
+            task=task,
+            robot_type=robot.robot_type,
+        )

-        with torch.inference_mode():
-            action_tensor = policy.select_action(batch)
-
-        # 5. Postprocess and convert tensor to dict
-        action_tensor = postprocessor(action_tensor)
-
-        # Flatten to 1D: take first timestep if chunks, squeeze batch dims
-        while action_tensor.dim() > 1:
-            action_tensor = action_tensor[0]
-
-        # Convert tensor to dict using action names from dataset
-        action_names = dataset.features[ACTION]["names"]
-        ee_action = {name: float(action_tensor[i]) for i, name in enumerate(action_names)}
+        # 5. Convert action tensor to dict
+        ee_action = make_robot_action(action_tensor, dataset.features)

         # 6. Convert relative action back to absolute if needed
         if use_relative_actions:
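This hunk is the heart of the commit: the hand-rolled preprocess / batch / select_action / postprocess / flatten sequence is replaced by the predict_action + make_robot_action pair, matching evaluate.py. One loose end: state_names is built from dataset.features but never used in the lines shown. The deleted code spells out roughly what make_robot_action must do; a sketch of that equivalent (the real helper lives in lerobot.policies.utils; this is not its code):

import torch

def make_robot_action_equiv(action_tensor: torch.Tensor, features: dict) -> dict:
    # Rough equivalent of the deleted lines: squeeze batch/chunk dims,
    # keep the first timestep, and map values onto the dataset's action names.
    while action_tensor.dim() > 1:
        action_tensor = action_tensor[0]
    names = features["action"]["names"]
    return {name: float(action_tensor[i]) for i, name in enumerate(names)}

# Example: a (batch=1, chunk=4, dof=3) tensor collapses to its first timestep.
t = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
feats = {"action": {"names": ["ee.x", "ee.y", "ee.z"]}}
print(make_robot_action_equiv(t, feats))  # {'ee.x': 0.0, 'ee.y': 1.0, 'ee.z': 2.0}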
@@ -370,11 +384,11 @@ def run_ee_inference_loop(
         # 8. Send joint commands to robot
         robot.send_action(joint_action)

-        # 9. Save frame to dataset
+        # 9. Save frame to dataset (save original robot obs + joint action)
         if dataset is not None:
-            observation_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
-            action_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
-            frame = {**observation_frame, **action_frame, "task": task}
+            obs_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
+            act_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
+            frame = {**obs_frame, **act_frame, "task": task}
             dataset.add_frame(frame)

         # 10. Visualization
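Review note: the rename also avoids shadowing the observation_frame dict built for the policy above, and the clarified comment is the important part; the dataset keeps the original joint-space robot_obs and the converted joint_action, not the EE-space or relative values the policy consumed, so recorded episodes stay replayable on the robot. A toy illustration of the frame merge (real frames come from build_dataset_frame against dataset.features; keys below are hypothetical):

obs_frame = {"observation.state": [0.10, 0.25, -0.05]}          # joint positions
act_frame = {"action": [0.12, 0.27, -0.04]}                     # joint targets
frame = {**obs_frame, **act_frame, "task": "pick up the cube"}  # one dataset row
assert set(frame) == {"observation.state", "action", "task"}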
@@ -392,6 +406,8 @@ def run_ee_inference_loop(
             sleep_time = dt - loop_duration
             if sleep_time > 0:
                 precise_sleep(sleep_time)

+        timestamp = time.perf_counter() - start_time
+
     print(f" Completed {step} steps")

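With timestamp now updated at the end of every iteration, the termination check (while timestamp < duration_s, changed above) and the per-step sleep form the same timing skeleton evaluate.py uses. A compact sketch of the whole pattern, assuming precise_sleep behaves like time.sleep with better accuracy:

import time

def run_loop_skeleton(duration_s: float = 1.0, fps: int = 30) -> None:
    dt = 1.0 / fps
    timestamp = 0.0                       # initialized before the first check
    start_time = time.perf_counter()
    while timestamp < duration_s:
        loop_start = time.perf_counter()
        # ... observe -> infer -> act would happen here ...
        sleep_time = dt - (time.perf_counter() - loop_start)
        if sleep_time > 0:
            time.sleep(sleep_time)        # stand-in for lerobot's precise_sleep
        timestamp = time.perf_counter() - start_time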