mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
feat(scripts): Introduce build_inference_frame/make_robot_action util to easily allow API-based Inference (#2143)
* fix: expose a function explicitly building a frame for inference * fix: first make dataset frame, then make ready for inference * fix: reducing reliance on lerobot record for policy's ouptuts too * fix: encapsulating squeezing out + device handling from predict action * fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion) * fix(policies): right utils signature + docstrings (#2198) --------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
committed by
GitHub
parent
bf6ac5e110
commit
723013c71b
@@ -31,6 +31,7 @@ from deepdiff import DeepDiff
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.robots import Robot
|
||||
|
||||
@@ -102,17 +103,7 @@ def predict_action(
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = prepare_observation_for_inference(observation, device, task, robot_type)
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
@@ -121,12 +112,6 @@ def predict_action(
|
||||
|
||||
action = postprocessor(action)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
return action
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user