mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b5201f6c15 | |||
| 58ccc01508 |
@@ -0,0 +1,79 @@
|
||||
#!/usr/bin/env python
|
||||
"""Convert a legacy LeRobot checkpoint to the current processor-pipeline format.
|
||||
|
||||
Older hub checkpoints (e.g. ``lerobot/act_aloha_sim_insertion_human``) bake
|
||||
normalization stats into the model weights and do not ship
|
||||
``policy_preprocessor.json`` / ``policy_postprocessor.json``. Current ``main``
|
||||
loads those processor configs from the checkpoint, so eval/rollout fail with
|
||||
``FileNotFoundError: Could not find 'policy_preprocessor.json'``.
|
||||
|
||||
This script rebuilds the processors from the training dataset's stats and saves
|
||||
a pipeline-format checkpoint locally that ``lerobot-eval`` can consume directly.
|
||||
|
||||
Usage:
|
||||
python examples/onnx/convert_legacy_checkpoint.py \
|
||||
--policy-path=lerobot/act_aloha_sim_insertion_human \
|
||||
--dataset-repo-id=lerobot/aloha_sim_insertion_human \
|
||||
--output-dir=outputs/converted/act_aloha_sim_insertion_human
|
||||
|
||||
Then:
|
||||
lerobot-eval \
|
||||
--policy.path=outputs/converted/act_aloha_sim_insertion_human \
|
||||
--env.type=aloha --env.task=AlohaInsertion-v0 \
|
||||
--eval.batch_size=10 --eval.n_episodes=50 \
|
||||
--eval.use_async_envs=false --policy.device=cuda
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.utils.constants import (
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--policy-path", required=True, help="Legacy checkpoint repo id or local dir")
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id",
|
||||
required=True,
|
||||
help="Training dataset repo id, used only for normalization stats",
|
||||
)
|
||||
parser.add_argument("--output-dir", required=True, help="Where to save the converted checkpoint")
|
||||
parser.add_argument("--device", default="cpu", help="Device for building the policy (cpu is fine)")
|
||||
args = parser.parse_args()
|
||||
|
||||
out = Path(args.output_dir)
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"[1/4] Loading dataset stats from '{args.dataset_repo_id}' (metadata only)...")
|
||||
ds_meta = LeRobotDatasetMetadata(args.dataset_repo_id)
|
||||
|
||||
print(f"[2/4] Loading policy weights from '{args.policy_path}'...")
|
||||
cfg = PreTrainedConfig.from_pretrained(args.policy_path)
|
||||
cfg.pretrained_path = args.policy_path
|
||||
cfg.device = args.device
|
||||
policy = make_policy(cfg, ds_meta=ds_meta)
|
||||
|
||||
print("[3/4] Building processors from dataset stats...")
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
dataset_stats=ds_meta.stats,
|
||||
)
|
||||
|
||||
print(f"[4/4] Saving pipeline-format checkpoint to '{out}'...")
|
||||
policy.save_pretrained(out)
|
||||
preprocessor.save_pretrained(out, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json")
|
||||
postprocessor.save_pretrained(out, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json")
|
||||
|
||||
print(f"\nDone. Converted checkpoint at: {out}")
|
||||
print("Eval it with --policy.path=" + str(out))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,179 @@
|
||||
#!/usr/bin/env python
|
||||
"""Evaluate an ACT policy in sim with either the PyTorch or ONNX network.
|
||||
|
||||
The ONNX backend swaps only ``policy.model`` (ResNet + transformer + action head)
|
||||
with an onnxruntime session. Everything else - the LeRobot processor pipeline
|
||||
(normalization), the action queue, and the gym env - is identical, so any
|
||||
difference in success rate is attributable to the network backend alone.
|
||||
|
||||
Run both backends with the same seed to compare:
|
||||
|
||||
python examples/onnx/eval_act_onnx.py \
|
||||
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--task=AlohaTransferCube-v0 \
|
||||
--backend=torch --n-episodes=50 --batch-size=10 --device=cuda
|
||||
|
||||
python examples/onnx/eval_act_onnx.py \
|
||||
--policy-path=lerobot/act_aloha_sim_transfer_cube_human \
|
||||
--task=AlohaTransferCube-v0 \
|
||||
--onnx=outputs/onnx/act_transfer_cube.onnx \
|
||||
--backend=onnx --n-episodes=50 --batch-size=10 --device=cuda
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.envs.factory import make_env, make_env_config, make_env_pre_post_processors
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.factory import make_pre_post_processors
|
||||
from lerobot.scripts.lerobot_eval import eval_policy
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
class ONNXACTModel(nn.Module):
|
||||
"""Drop-in replacement for ``ACTPolicy.model`` backed by onnxruntime."""
|
||||
|
||||
def __init__(self, onnx_path: str, image_keys: list[str], has_state: bool, has_env_state: bool, device: str):
|
||||
super().__init__()
|
||||
import onnxruntime as ort
|
||||
|
||||
providers = (
|
||||
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
if str(device).startswith("cuda")
|
||||
else ["CPUExecutionProvider"]
|
||||
)
|
||||
so = ort.SessionOptions()
|
||||
so.log_severity_level = 3
|
||||
self.sess = ort.InferenceSession(onnx_path, sess_options=so, providers=providers)
|
||||
self.image_keys = image_keys
|
||||
self.has_state = has_state
|
||||
self.has_env_state = has_env_state
|
||||
print(f"[onnx] providers in use: {self.sess.get_providers()}")
|
||||
|
||||
def forward(self, batch: dict):
|
||||
if self.has_state:
|
||||
state = batch[OBS_STATE]
|
||||
else:
|
||||
state = batch[OBS_ENV_STATE]
|
||||
ref = state
|
||||
ort_inputs = {"state": state.detach().cpu().numpy().astype(np.float32)}
|
||||
images = batch[OBS_IMAGES]
|
||||
for i, img in enumerate(images):
|
||||
ort_inputs[f"image_{i}"] = img.detach().cpu().numpy().astype(np.float32)
|
||||
out = self.sess.run(None, ort_inputs)[0]
|
||||
actions = torch.from_numpy(out).to(ref.device, dtype=ref.dtype)
|
||||
return actions, None
|
||||
|
||||
|
||||
def load_stats_from_checkpoint(policy_path: str, input_features, output_features) -> dict:
|
||||
"""Recover MEAN_STD stats baked into a legacy ACT checkpoint's safetensors buffers.
|
||||
|
||||
Legacy checkpoints store normalization as buffers like
|
||||
``normalize_inputs.buffer_observation_state.{mean,std}``. We map those back to
|
||||
feature names so we can rebuild the processor pipeline without the dataset.
|
||||
"""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
p = Path(policy_path)
|
||||
if p.is_dir():
|
||||
st_path = p / "model.safetensors"
|
||||
else:
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
st_path = Path(hf_hub_download(policy_path, "model.safetensors"))
|
||||
|
||||
sd = load_file(str(st_path))
|
||||
stats: dict = {}
|
||||
for feat in list(input_features) + list(output_features):
|
||||
buf = "buffer_" + feat.replace(".", "_")
|
||||
for prefix in ("normalize_inputs", "normalize_targets", "unnormalize_outputs"):
|
||||
mkey, skey = f"{prefix}.{buf}.mean", f"{prefix}.{buf}.std"
|
||||
if mkey in sd and skey in sd:
|
||||
stats[feat] = {"mean": sd[mkey].numpy(), "std": sd[skey].numpy()}
|
||||
break
|
||||
return stats
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--policy-path", required=True)
|
||||
parser.add_argument("--task", required=True, help="e.g. AlohaTransferCube-v0")
|
||||
parser.add_argument("--env-type", default="aloha")
|
||||
parser.add_argument("--backend", choices=["torch", "onnx"], default="torch")
|
||||
parser.add_argument("--onnx", default=None, help="Path to .onnx (required for --backend=onnx)")
|
||||
parser.add_argument("--n-episodes", type=int, default=50)
|
||||
parser.add_argument("--batch-size", type=int, default=10)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--seed", type=int, default=1000)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.backend == "onnx" and not args.onnx:
|
||||
raise SystemExit("--backend=onnx requires --onnx=<path>")
|
||||
|
||||
device = "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu"
|
||||
set_seed(args.seed)
|
||||
|
||||
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
|
||||
policy = ACTPolicy.from_pretrained(args.policy_path)
|
||||
policy.config.device = device
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
cfg = policy.config
|
||||
|
||||
if args.backend == "onnx":
|
||||
image_keys = list(cfg.image_features)
|
||||
has_state = cfg.robot_state_feature is not None
|
||||
has_env_state = cfg.env_state_feature is not None
|
||||
print(f"[2/4] Swapping policy.model with ONNX backend ({args.onnx})")
|
||||
policy.model = ONNXACTModel(args.onnx, image_keys, has_state, has_env_state, device)
|
||||
policy.to(device)
|
||||
else:
|
||||
print("[2/4] Using PyTorch backend")
|
||||
|
||||
print("[3/4] Building processors and environment...")
|
||||
stats = load_stats_from_checkpoint(args.policy_path, cfg.input_features, cfg.output_features)
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg,
|
||||
dataset_stats=stats,
|
||||
preprocessor_overrides={"device_processor": {"device": device}},
|
||||
)
|
||||
|
||||
env_cfg = make_env_config(args.env_type, task=args.task)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=env_cfg, policy_cfg=cfg)
|
||||
env_groups = make_env(env_cfg, n_envs=args.batch_size, use_async_envs=False)
|
||||
# make_env returns {task_group: {idx: VectorEnv}}; grab the single env.
|
||||
first_group = next(iter(env_groups.values()))
|
||||
env = next(iter(first_group.values()))
|
||||
|
||||
print(f"[4/4] Evaluating backend='{args.backend}' for {args.n_episodes} episodes (seed={args.seed})...")
|
||||
with torch.no_grad():
|
||||
info = eval_policy(
|
||||
env=env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=args.n_episodes,
|
||||
start_seed=args.seed,
|
||||
)
|
||||
|
||||
agg = info["aggregated"]
|
||||
print("\n==== RESULT ====")
|
||||
print(f"backend : {args.backend}")
|
||||
print(f"task : {args.task}")
|
||||
print(f"n_episodes : {args.n_episodes}")
|
||||
print(f"pc_success : {agg['pc_success']:.1f}%")
|
||||
print(f"avg_max_reward: {agg['avg_max_reward']:.4f}")
|
||||
print(f"eval_ep_s : {agg['eval_ep_s']:.2f}s")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python
|
||||
"""Export an ACT policy's network to ONNX and verify numerical parity.
|
||||
|
||||
Only the inference network is exported (ResNet backbone + transformer enc/dec +
|
||||
action head). The VAE encoder is training-only and the inference latent is zeros,
|
||||
so the exported graph is a pure function of (state, images) -> action_chunk.
|
||||
Normalization stays in the LeRobot processor pipeline (outside ONNX).
|
||||
|
||||
Usage:
|
||||
python examples/onnx/export_act.py \
|
||||
--policy-path=outputs/converted/act_aloha_sim_transfer_cube_human \
|
||||
--output=outputs/onnx/act_transfer_cube.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
|
||||
|
||||
|
||||
class ACTExportWrapper(nn.Module):
|
||||
"""Tensor-in/tensor-out wrapper around ACT's inference network."""
|
||||
|
||||
def __init__(self, model: nn.Module, image_keys: list[str], has_state: bool, has_env_state: bool):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.image_keys = image_keys
|
||||
self.has_state = has_state
|
||||
self.has_env_state = has_env_state
|
||||
|
||||
def forward(self, state: torch.Tensor, *images: torch.Tensor) -> torch.Tensor:
|
||||
batch: dict = {}
|
||||
if self.has_state:
|
||||
batch[OBS_STATE] = state
|
||||
if self.has_env_state:
|
||||
# Convention: when env_state is used it is passed as `state`.
|
||||
batch[OBS_ENV_STATE] = state
|
||||
batch[OBS_IMAGES] = list(images)
|
||||
actions, _ = self.model(batch)
|
||||
return actions
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--policy-path", required=True, help="Converted ACT checkpoint dir or repo id")
|
||||
parser.add_argument("--output", required=True, help="Output .onnx path")
|
||||
parser.add_argument("--opset", type=int, default=17)
|
||||
parser.add_argument("--atol", type=float, default=1e-3)
|
||||
parser.add_argument("--device", default="cpu")
|
||||
args = parser.parse_args()
|
||||
|
||||
out = Path(args.output)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"[1/4] Loading ACT policy from '{args.policy_path}'...")
|
||||
policy = ACTPolicy.from_pretrained(args.policy_path)
|
||||
policy.eval()
|
||||
policy.to(args.device)
|
||||
cfg = policy.config
|
||||
|
||||
image_keys = list(cfg.image_features)
|
||||
has_state = cfg.robot_state_feature is not None
|
||||
has_env_state = cfg.env_state_feature is not None
|
||||
state_dim = (cfg.robot_state_feature or cfg.env_state_feature).shape[0]
|
||||
|
||||
print(f" image_keys={image_keys} state_dim={state_dim} "
|
||||
f"chunk_size={cfg.chunk_size} action_dim={cfg.action_feature.shape[0]}")
|
||||
|
||||
wrapper = ACTExportWrapper(policy.model, image_keys, has_state, has_env_state).eval().to(args.device)
|
||||
|
||||
# Build example inputs (batch size 1) from the config feature shapes.
|
||||
state_example = torch.randn(1, state_dim, device=args.device)
|
||||
image_examples = [
|
||||
torch.rand(1, *cfg.image_features[k].shape, device=args.device) for k in image_keys
|
||||
]
|
||||
example_inputs = (state_example, *image_examples)
|
||||
|
||||
input_names = ["state"] + [f"image_{i}" for i in range(len(image_keys))]
|
||||
output_names = ["action_chunk"]
|
||||
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
|
||||
|
||||
print(f"[2/4] Exporting to ONNX (opset {args.opset}) -> {out}")
|
||||
torch.onnx.export(
|
||||
wrapper,
|
||||
example_inputs,
|
||||
str(out),
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=args.opset,
|
||||
do_constant_folding=True,
|
||||
dynamo=False,
|
||||
)
|
||||
|
||||
print("[3/4] Running parity check (torch vs onnxruntime)...")
|
||||
import onnxruntime as ort
|
||||
|
||||
providers = ["CPUExecutionProvider"]
|
||||
so = ort.SessionOptions()
|
||||
so.log_severity_level = 3
|
||||
sess = ort.InferenceSession(str(out), sess_options=so, providers=providers)
|
||||
|
||||
# Fresh random inputs for the check.
|
||||
state_check = torch.randn(2, state_dim, device=args.device)
|
||||
image_check = [torch.rand(2, *cfg.image_features[k].shape, device=args.device) for k in image_keys]
|
||||
|
||||
with torch.no_grad():
|
||||
torch_out = wrapper(state_check, *image_check).cpu().numpy()
|
||||
|
||||
ort_inputs = {"state": state_check.cpu().numpy()}
|
||||
for i, img in enumerate(image_check):
|
||||
ort_inputs[f"image_{i}"] = img.cpu().numpy()
|
||||
ort_out = sess.run(None, ort_inputs)[0]
|
||||
|
||||
max_abs = float(np.max(np.abs(torch_out - ort_out)))
|
||||
mean_abs = float(np.mean(np.abs(torch_out - ort_out)))
|
||||
print(f" shapes: torch={torch_out.shape} onnx={ort_out.shape}")
|
||||
print(f" max_abs_diff={max_abs:.3e} mean_abs_diff={mean_abs:.3e} (atol={args.atol:.0e})")
|
||||
|
||||
ok = max_abs <= args.atol
|
||||
print(f"[4/4] Parity: {'PASS' if ok else 'FAIL'}")
|
||||
if not ok:
|
||||
raise SystemExit(f"Parity check failed: max_abs_diff {max_abs:.3e} > atol {args.atol:.0e}")
|
||||
print(f"\nDone. ONNX model at: {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -54,6 +54,7 @@ from typing import Any
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
||||
# them into one). Write to a sibling tmp path and atomically rename so a
|
||||
# crash mid-write can't leave a half-written shard.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
write_table_one_row_group_per_episode(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
|
||||
@@ -442,12 +442,11 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
|
||||
@@ -471,12 +471,11 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
|
||||
@@ -246,12 +246,11 @@ class ZMQCamera(Camera):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
if self.stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not stop_event.is_set():
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_one_row_group_per_episode,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
one_row_group_per_episode=True,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
one_row_group_per_episode: bool = False,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
||||
the episodes-metadata parquet (already one row per episode).
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(df, dst_path)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, (dst_chunk, dst_file)
|
||||
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
return items_dict
|
||||
|
||||
|
||||
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
||||
"""Write ``table`` with one parquet row group per episode (in episode order).
|
||||
|
||||
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
|
||||
mirroring the recording writer. ``table`` must carry a contiguous
|
||||
``episode_index`` column.
|
||||
"""
|
||||
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
||||
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
||||
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
||||
try:
|
||||
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
|
||||
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
||||
does not embed external image files like ``Dataset.to_parquet`` does).
|
||||
``features`` types image columns as ``Image()`` in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
ds = embed_images(ds)
|
||||
table = ds.with_format("arrow")[:]
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
# No episode boundaries to align row groups to — keep a single write.
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
|
||||
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -28,6 +28,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
|
||||
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
|
||||
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
|
||||
for ep, length in enumerate(episode_lengths):
|
||||
episode_index += [ep] * length
|
||||
frame_index += list(range(length))
|
||||
timestamp += [round(i / fps, 6) for i in range(length)]
|
||||
task_index += [0] * length
|
||||
subtask_index += [0] * length # legacy column the writer must drop
|
||||
pd.DataFrame(
|
||||
{
|
||||
"episode_index": episode_index,
|
||||
"frame_index": frame_index,
|
||||
"timestamp": timestamp,
|
||||
"task_index": task_index,
|
||||
"subtask_index": subtask_index,
|
||||
}
|
||||
).to_parquet(data_dir / "file-000.parquet", index=False)
|
||||
|
||||
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
|
||||
write_tasks(tasks_df, root)
|
||||
write_json(
|
||||
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
|
||||
"""Rewriting a packed shard must keep one row group per episode, not collapse
|
||||
every episode into a single giant row group."""
|
||||
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
|
||||
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
|
||||
shard = root / "data" / "chunk-000" / "file-000.parquet"
|
||||
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
for ep in range(len(episode_lengths)):
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
ep,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"subtask for ep {ep}",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
records = list(iter_episodes(root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, root)
|
||||
|
||||
# One row group per episode, with row counts matching the episode lengths.
|
||||
md = pq.ParquetFile(shard).metadata
|
||||
assert md.num_row_groups == len(episode_lengths)
|
||||
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
|
||||
# Language columns are still present after the per-episode rewrite.
|
||||
table = pq.read_table(shard)
|
||||
assert "language_persistent" in table.column_names
|
||||
assert "language_events" in table.column_names
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
|
||||
@@ -32,6 +32,26 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def assert_data_shards_one_row_group_per_episode(root):
|
||||
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
shards = sorted((root / "data").rglob("*.parquet"))
|
||||
assert shards, f"no data shards found under {root}/data"
|
||||
n_episodes = 0
|
||||
for shard in shards:
|
||||
pf = pq.ParquetFile(shard)
|
||||
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
assert pf.metadata.num_row_groups == len(set(episodes)), shard
|
||||
for i in range(pf.metadata.num_row_groups):
|
||||
rg_episodes = set(
|
||||
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
)
|
||||
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
|
||||
n_episodes += len(set(episodes))
|
||||
return n_episodes
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||
assert aggr_ds.num_episodes == expected_episodes, (
|
||||
@@ -566,6 +586,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
|
||||
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
|
||||
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
|
||||
|
||||
Covers both the non-image (``df.to_parquet``) and image
|
||||
(``to_parquet_with_hf_images``) write branches, including the merge-into-
|
||||
existing-file branch via a low file-size threshold that forces packing.
|
||||
"""
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_0",
|
||||
total_episodes=3,
|
||||
total_frames=60,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_1",
|
||||
total_episodes=4,
|
||||
total_frames=80,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
|
||||
aggr_root = tmp_path / "rg_aggr"
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
|
||||
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
|
||||
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user