Compare commits

..

1 Commits

Author SHA1 Message Date
Martino Russi b5201f6c15 add onnx support 2026-06-16 15:15:48 +02:00
8 changed files with 391 additions and 10 deletions
@@ -0,0 +1,79 @@
#!/usr/bin/env python
"""Convert a legacy LeRobot checkpoint to the current processor-pipeline format.
Older hub checkpoints (e.g. ``lerobot/act_aloha_sim_insertion_human``) bake
normalization stats into the model weights and do not ship
``policy_preprocessor.json`` / ``policy_postprocessor.json``. Current ``main``
loads those processor configs from the checkpoint, so eval/rollout fail with
``FileNotFoundError: Could not find 'policy_preprocessor.json'``.
This script rebuilds the processors from the training dataset's stats and saves
a pipeline-format checkpoint locally that ``lerobot-eval`` can consume directly.
Usage:
python examples/onnx/convert_legacy_checkpoint.py \
--policy-path=lerobot/act_aloha_sim_insertion_human \
--dataset-repo-id=lerobot/aloha_sim_insertion_human \
--output-dir=outputs/converted/act_aloha_sim_insertion_human
Then:
lerobot-eval \
--policy.path=outputs/converted/act_aloha_sim_insertion_human \
--env.type=aloha --env.task=AlohaInsertion-v0 \
--eval.batch_size=10 --eval.n_episodes=50 \
--eval.use_async_envs=false --policy.device=cuda
"""
import argparse
from pathlib import Path
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.utils.constants import (
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True, help="Legacy checkpoint repo id or local dir")
parser.add_argument(
"--dataset-repo-id",
required=True,
help="Training dataset repo id, used only for normalization stats",
)
parser.add_argument("--output-dir", required=True, help="Where to save the converted checkpoint")
parser.add_argument("--device", default="cpu", help="Device for building the policy (cpu is fine)")
args = parser.parse_args()
out = Path(args.output_dir)
out.mkdir(parents=True, exist_ok=True)
print(f"[1/4] Loading dataset stats from '{args.dataset_repo_id}' (metadata only)...")
ds_meta = LeRobotDatasetMetadata(args.dataset_repo_id)
print(f"[2/4] Loading policy weights from '{args.policy_path}'...")
cfg = PreTrainedConfig.from_pretrained(args.policy_path)
cfg.pretrained_path = args.policy_path
cfg.device = args.device
policy = make_policy(cfg, ds_meta=ds_meta)
print("[3/4] Building processors from dataset stats...")
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=policy.config,
dataset_stats=ds_meta.stats,
)
print(f"[4/4] Saving pipeline-format checkpoint to '{out}'...")
policy.save_pretrained(out)
preprocessor.save_pretrained(out, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json")
postprocessor.save_pretrained(out, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json")
print(f"\nDone. Converted checkpoint at: {out}")
print("Eval it with --policy.path=" + str(out))
if __name__ == "__main__":
main()
+179
View File
@@ -0,0 +1,179 @@
#!/usr/bin/env python
"""Evaluate an ACT policy in sim with either the PyTorch or ONNX network.
The ONNX backend swaps only ``policy.model`` (ResNet + transformer + action head)
with an onnxruntime session. Everything else - the LeRobot processor pipeline
(normalization), the action queue, and the gym env - is identical, so any
difference in success rate is attributable to the network backend alone.
Run both backends with the same seed to compare:
python examples/onnx/eval_act_onnx.py \
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
--task=AlohaTransferCube-v0 \
--backend=torch --n-episodes=50 --batch-size=10 --device=cuda
python examples/onnx/eval_act_onnx.py \
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
--task=AlohaTransferCube-v0 \
--onnx=outputs/onnx/act_transfer_cube.onnx \
--backend=onnx --n-episodes=50 --batch-size=10 --device=cuda
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from torch import nn
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.scripts.lerobot_eval import eval_policy
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import set_seed
class ONNXACTModel(nn.Module):
"""Drop-in replacement for ``ACTPolicy.model`` backed by onnxruntime."""
def __init__(self, onnx_path: str, image_keys: list[str], has_state: bool, has_env_state: bool, device: str):
super().__init__()
import onnxruntime as ort
providers = (
["CUDAExecutionProvider", "CPUExecutionProvider"]
if str(device).startswith("cuda")
else ["CPUExecutionProvider"]
)
so = ort.SessionOptions()
so.log_severity_level = 3
self.sess = ort.InferenceSession(onnx_path, sess_options=so, providers=providers)
self.image_keys = image_keys
self.has_state = has_state
self.has_env_state = has_env_state
print(f"[onnx] providers in use: {self.sess.get_providers()}")
def forward(self, batch: dict):
if self.has_state:
state = batch[OBS_STATE]
else:
state = batch[OBS_ENV_STATE]
ref = state
ort_inputs = {"state": state.detach().cpu().numpy().astype(np.float32)}
images = batch[OBS_IMAGES]
for i, img in enumerate(images):
ort_inputs[f"image_{i}"] = img.detach().cpu().numpy().astype(np.float32)
out = self.sess.run(None, ort_inputs)[0]
actions = torch.from_numpy(out).to(ref.device, dtype=ref.dtype)
return actions, None
def load_stats_from_checkpoint(policy_path: str, input_features, output_features) -> dict:
"""Recover MEAN_STD stats baked into a legacy ACT checkpoint's safetensors buffers.
Legacy checkpoints store normalization as buffers like
``normalize_inputs.buffer_observation_state.{mean,std}``. We map those back to
feature names so we can rebuild the processor pipeline without the dataset.
"""
from safetensors.torch import load_file
p = Path(policy_path)
if p.is_dir():
st_path = p / "model.safetensors"
else:
from huggingface_hub import hf_hub_download
st_path = Path(hf_hub_download(policy_path, "model.safetensors"))
sd = load_file(str(st_path))
stats: dict = {}
for feat in list(input_features) + list(output_features):
buf = "buffer_" + feat.replace(".", "_")
for prefix in ("normalize_inputs", "normalize_targets", "unnormalize_outputs"):
mkey, skey = f"{prefix}.{buf}.mean", f"{prefix}.{buf}.std"
if mkey in sd and skey in sd:
stats[feat] = {"mean": sd[mkey].numpy(), "std": sd[skey].numpy()}
break
return stats
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True)
parser.add_argument("--task", required=True, help="e.g. AlohaTransferCube-v0")
parser.add_argument("--env-type", default="aloha")
parser.add_argument("--backend", choices=["torch", "onnx"], default="torch")
parser.add_argument("--onnx", default=None, help="Path to .onnx (required for --backend=onnx)")
parser.add_argument("--n-episodes", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=10)
parser.add_argument("--device", default="cuda")
parser.add_argument("--seed", type=int, default=1000)
args = parser.parse_args()
if args.backend == "onnx" and not args.onnx:
raise SystemExit("--backend=onnx requires --onnx=<path>")
device = "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
set_seed(args.seed)
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
policy = ACTPolicy.from_pretrained(args.policy_path)
policy.config.device = device
policy.eval()
policy.to(device)
cfg = policy.config
if args.backend == "onnx":
image_keys = list(cfg.image_features)
has_state = cfg.robot_state_feature is not None
has_env_state = cfg.env_state_feature is not None
print(f"[2/4] Swapping policy.model with ONNX backend ({args.onnx})")
policy.model = ONNXACTModel(args.onnx, image_keys, has_state, has_env_state, device)
policy.to(device)
else:
print("[2/4] Using PyTorch backend")
print("[3/4] Building processors and environment...")
stats = load_stats_from_checkpoint(args.policy_path, cfg.input_features, cfg.output_features)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg,
dataset_stats=stats,
preprocessor_overrides={"device_processor": {"device": device}},
)
env_cfg = make_env_config(args.env_type, task=args.task)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=env_cfg, policy_cfg=cfg)
env_groups = make_env(env_cfg, n_envs=args.batch_size, use_async_envs=False)
# make_env returns {task_group: {idx: VectorEnv}}; grab the single env.
first_group = next(iter(env_groups.values()))
env = next(iter(first_group.values()))
print(f"[4/4] Evaluating backend='{args.backend}' for {args.n_episodes} episodes (seed={args.seed})...")
with torch.no_grad():
info = eval_policy(
env=env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=args.n_episodes,
start_seed=args.seed,
)
agg = info["aggregated"]
print("\n==== RESULT ====")
print(f"backend : {args.backend}")
print(f"task : {args.task}")
print(f"n_episodes : {args.n_episodes}")
print(f"pc_success : {agg['pc_success']:.1f}%")
print(f"avg_max_reward: {agg['avg_max_reward']:.4f}")
print(f"eval_ep_s : {agg['eval_ep_s']:.2f}s")
env.close()
if __name__ == "__main__":
main()
+133
View File
@@ -0,0 +1,133 @@
#!/usr/bin/env python
"""Export an ACT policy's network to ONNX and verify numerical parity.
Only the inference network is exported (ResNet backbone + transformer enc/dec +
action head). The VAE encoder is training-only and the inference latent is zeros,
so the exported graph is a pure function of (state, images) -> action_chunk.
Normalization stays in the LeRobot processor pipeline (outside ONNX).
Usage:
python examples/onnx/export_act.py \
--policy-path=outputs/converted/act_aloha_sim_transfer_cube_human \
--output=outputs/onnx/act_transfer_cube.onnx
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from torch import nn
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
class ACTExportWrapper(nn.Module):
"""Tensor-in/tensor-out wrapper around ACT's inference network."""
def __init__(self, model: nn.Module, image_keys: list[str], has_state: bool, has_env_state: bool):
super().__init__()
self.model = model
self.image_keys = image_keys
self.has_state = has_state
self.has_env_state = has_env_state
def forward(self, state: torch.Tensor, *images: torch.Tensor) -> torch.Tensor:
batch: dict = {}
if self.has_state:
batch[OBS_STATE] = state
if self.has_env_state:
# Convention: when env_state is used it is passed as `state`.
batch[OBS_ENV_STATE] = state
batch[OBS_IMAGES] = list(images)
actions, _ = self.model(batch)
return actions
def main():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--policy-path", required=True, help="Converted ACT checkpoint dir or repo id")
parser.add_argument("--output", required=True, help="Output .onnx path")
parser.add_argument("--opset", type=int, default=17)
parser.add_argument("--atol", type=float, default=1e-3)
parser.add_argument("--device", default="cpu")
args = parser.parse_args()
out = Path(args.output)
out.parent.mkdir(parents=True, exist_ok=True)
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
policy = ACTPolicy.from_pretrained(args.policy_path)
policy.eval()
policy.to(args.device)
cfg = policy.config
image_keys = list(cfg.image_features)
has_state = cfg.robot_state_feature is not None
has_env_state = cfg.env_state_feature is not None
state_dim = (cfg.robot_state_feature or cfg.env_state_feature).shape[0]
print(f" image_keys={image_keys} state_dim={state_dim} "
f"chunk_size={cfg.chunk_size} action_dim={cfg.action_feature.shape[0]}")
wrapper = ACTExportWrapper(policy.model, image_keys, has_state, has_env_state).eval().to(args.device)
# Build example inputs (batch size 1) from the config feature shapes.
state_example = torch.randn(1, state_dim, device=args.device)
image_examples = [
torch.rand(1, *cfg.image_features[k].shape, device=args.device) for k in image_keys
]
example_inputs = (state_example, *image_examples)
input_names = ["state"] + [f"image_{i}" for i in range(len(image_keys))]
output_names = ["action_chunk"]
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
print(f"[2/4] Exporting to ONNX (opset {args.opset}) -> {out}")
torch.onnx.export(
wrapper,
example_inputs,
str(out),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=args.opset,
do_constant_folding=True,
dynamo=False,
)
print("[3/4] Running parity check (torch vs onnxruntime)...")
import onnxruntime as ort
providers = ["CPUExecutionProvider"]
so = ort.SessionOptions()
so.log_severity_level = 3
sess = ort.InferenceSession(str(out), sess_options=so, providers=providers)
# Fresh random inputs for the check.
state_check = torch.randn(2, state_dim, device=args.device)
image_check = [torch.rand(2, *cfg.image_features[k].shape, device=args.device) for k in image_keys]
with torch.no_grad():
torch_out = wrapper(state_check, *image_check).cpu().numpy()
ort_inputs = {"state": state_check.cpu().numpy()}
for i, img in enumerate(image_check):
ort_inputs[f"image_{i}"] = img.cpu().numpy()
ort_out = sess.run(None, ort_inputs)[0]
max_abs = float(np.max(np.abs(torch_out - ort_out)))
mean_abs = float(np.mean(np.abs(torch_out - ort_out)))
print(f" shapes: torch={torch_out.shape} onnx={ort_out.shape}")
print(f" max_abs_diff={max_abs:.3e} mean_abs_diff={mean_abs:.3e} (atol={args.atol:.0e})")
ok = max_abs <= args.atol
print(f"[4/4] Parity: {'PASS' if ok else 'FAIL'}")
if not ok:
raise SystemExit(f"Parity check failed: max_abs_diff {max_abs:.3e} > atol {args.atol:.0e}")
print(f"\nDone. ONNX model at: {out}")
if __name__ == "__main__":
main()
-2
View File
@@ -79,8 +79,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: Path | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
pretrained_revision: str | None = None
def __post_init__(self) -> None:
if not self.device or not is_torch_device_available(self.device):
-2
View File
@@ -56,8 +56,6 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
device: str | None = None
pretrained_path: str | None = None
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
pretrained_revision: str | None = None
push_to_hub: bool = False
repo_id: str | None = None
-4
View File
@@ -252,7 +252,6 @@ class ProcessorConfigKwargs(TypedDict, total=False):
def make_pre_post_processors(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
pretrained_revision: str | None = None,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -310,7 +309,6 @@ def make_pre_post_processors(
overrides=kwargs.get("preprocessor_overrides", {}),
to_transition=batch_to_transition,
to_output=transition_to_batch,
revision=pretrained_revision,
)
postprocessor = PolicyProcessorPipeline.from_pretrained(
pretrained_model_name_or_path=pretrained_path,
@@ -320,7 +318,6 @@ def make_pre_post_processors(
overrides=kwargs.get("postprocessor_overrides", {}),
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
revision=pretrained_revision,
)
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
return preprocessor, postprocessor
@@ -560,7 +557,6 @@ def make_policy(
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
policy = policy_cls.from_pretrained(**kwargs)
elif cfg.pretrained_path and cfg.use_peft:
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
-1
View File
@@ -124,7 +124,6 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
if cfg.pretrained_path:
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
kwargs["revision"] = cfg.pretrained_revision
reward_model = reward_cls.from_pretrained(**kwargs)
else:
reward_model = reward_cls(**kwargs)
-1
View File
@@ -345,7 +345,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
**processor_kwargs,
)