feat(scripts): Introduce build_inference_frame/make_robot_action util to easily allow API-based Inference (#2143)

* fix: expose a function explicitly building a frame for inference

* fix: first make dataset frame, then make ready for inference

* fix: reducing reliance on lerobot record for policy's ouptuts too

* fix: encapsulating squeezing out + device handling from predict action

* fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion)

* fix(policies): right utils signature + docstrings (#2198)

---------

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Francesco Capuano
2025-10-14 15:47:32 +02:00
committed by GitHub
parent bf6ac5e110
commit 723013c71b
3 changed files with 117 additions and 21 deletions
+2 -17
View File
@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.robots import Robot
@@ -102,17 +103,7 @@ def predict_action(
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)
# Compute the next action with the policy
@@ -121,12 +112,6 @@ def predict_action(
action = postprocessor(action)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action