mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-20 01:37:08 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b06ad40888 | |||
| b3d74f80f0 | |||
| 552b4c3563 | |||
| 8bf6056d14 | |||
| da92db8fc0 | |||
| 2b0834bcb8 | |||
| 287c823f13 | |||
| 58ccc01508 |
@@ -136,6 +136,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ from typing import Any
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from lerobot.datasets.io_utils import write_table_one_row_group_per_episode
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
@@ -274,12 +275,11 @@ class LanguageColumnsWriter:
|
||||
new_table = self._materialize_table(
|
||||
table, per_row_persistent, per_row_events, drop_old=self.drop_existing_subtask_index
|
||||
)
|
||||
# Atomic replace: write to a sibling tmp path and rename so a crash
|
||||
# mid-write can't leave a half-written shard that ``pq.read_table``
|
||||
# would then fail to open. ``Path.replace`` is atomic on POSIX +
|
||||
# Windows when source and target sit on the same filesystem.
|
||||
# Re-emit one row group per episode (a bulk pq.write_table would collapse
|
||||
# them into one). Write to a sibling tmp path and atomically rename so a
|
||||
# crash mid-write can't leave a half-written shard.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
pq.write_table(new_table, tmp_path)
|
||||
write_table_one_row_group_per_episode(new_table, tmp_path)
|
||||
tmp_path.replace(path)
|
||||
|
||||
def _materialize_table(
|
||||
|
||||
@@ -442,11 +442,12 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
|
||||
@@ -471,11 +471,12 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
|
||||
@@ -246,11 +246,12 @@ class ZMQCamera(Camera):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -180,24 +180,26 @@ class WandBLogger:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
batch_data = {}
|
||||
for k, v in d.items():
|
||||
# Skip the custom step key here, it's added to the batch below.
|
||||
if custom_step_key is not None and k == custom_step_key:
|
||||
continue
|
||||
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
batch_data[f"{mode}/{k}"] = v
|
||||
|
||||
if batch_data:
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||
self._wandb.log(batch_data)
|
||||
else:
|
||||
self._wandb.log(data=batch_data, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -32,6 +32,7 @@ from .feature_utils import features_equal_for_merge, get_hf_features_from_featur
|
||||
from .io_utils import (
|
||||
get_file_size_in_mb,
|
||||
get_parquet_file_size_in_mb,
|
||||
to_parquet_one_row_group_per_episode,
|
||||
to_parquet_with_hf_images,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -551,6 +552,7 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si
|
||||
aggr_root=dst_meta.root,
|
||||
hf_features=hf_features,
|
||||
concatenate=concatenate_data,
|
||||
one_row_group_per_episode=True,
|
||||
)
|
||||
|
||||
# Record the mapping from source to actual destination
|
||||
@@ -628,6 +630,7 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Path = None,
|
||||
hf_features: datasets.Features | None = None,
|
||||
concatenate: bool = True,
|
||||
one_row_group_per_episode: bool = False,
|
||||
) -> tuple[dict[str, int], tuple[int, int]]:
|
||||
"""Appends data to an existing parquet file or creates a new one based on size constraints.
|
||||
|
||||
@@ -645,6 +648,8 @@ def append_or_create_parquet_file(
|
||||
aggr_root: Root path for the aggregated dataset.
|
||||
hf_features: Optional HuggingFace Features schema for proper image typing.
|
||||
concatenate: When False, always rotate to a new file instead of appending to the current one.
|
||||
one_row_group_per_episode: True for DATA parquet (emit one row group per episode); False for
|
||||
the episodes-metadata parquet (already one row per episode).
|
||||
|
||||
Returns:
|
||||
tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict
|
||||
@@ -657,6 +662,8 @@ def append_or_create_parquet_file(
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(df, dst_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(df, dst_path)
|
||||
else:
|
||||
df.to_parquet(dst_path)
|
||||
return idx, (dst_chunk, dst_file)
|
||||
@@ -683,6 +690,8 @@ def append_or_create_parquet_file(
|
||||
|
||||
if contains_images:
|
||||
to_parquet_with_hf_images(final_df, target_path, features=hf_features)
|
||||
elif one_row_group_per_episode:
|
||||
to_parquet_one_row_group_per_episode(final_df, target_path)
|
||||
else:
|
||||
final_df.to_parquet(target_path)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -74,6 +74,8 @@ class DatasetReader:
|
||||
self.episodes = episodes
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
@@ -86,6 +88,16 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self._image_transforms = None
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import datasets
|
||||
import numpy as np
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
@@ -153,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
@@ -270,21 +271,49 @@ def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[to
|
||||
return items_dict
|
||||
|
||||
|
||||
def write_table_one_row_group_per_episode(table: pa.Table, path: Path) -> None:
|
||||
"""Write ``table`` with one parquet row group per episode (in episode order).
|
||||
|
||||
Keeps shards random-access friendly (``read_row_group(i)`` fetches episode i),
|
||||
mirroring the recording writer. ``table`` must carry a contiguous
|
||||
``episode_index`` column.
|
||||
"""
|
||||
episode_index = table.column("episode_index").to_numpy(zero_copy_only=False)
|
||||
starts = np.concatenate(([0], np.nonzero(np.diff(episode_index))[0] + 1))
|
||||
writer = pq.ParquetWriter(str(path), table.schema, compression="snappy", use_dictionary=True)
|
||||
try:
|
||||
for start, stop in zip(starts, np.append(starts[1:], len(episode_index)), strict=True):
|
||||
writer.write_table(table.slice(start, stop - start)) # one episode -> one row group
|
||||
finally:
|
||||
writer.close()
|
||||
|
||||
|
||||
def to_parquet_with_hf_images(
|
||||
df: pandas.DataFrame, path: Path, features: datasets.Features | None = None
|
||||
) -> None:
|
||||
"""This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
|
||||
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||
"""Write a DataFrame with HF-encoded images to parquet, one row group per episode.
|
||||
|
||||
Args:
|
||||
df: DataFrame to write to parquet.
|
||||
path: Path to write the parquet file.
|
||||
features: Optional HuggingFace Features schema. If provided, ensures image columns
|
||||
are properly typed as Image() in the parquet schema.
|
||||
Images are embedded into the arrow table first (``ParquetWriter.write_table``
|
||||
does not embed external image files like ``Dataset.to_parquet`` does).
|
||||
``features`` types image columns as ``Image()`` in the parquet schema.
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features)
|
||||
ds.to_parquet(path)
|
||||
ds = embed_images(ds)
|
||||
table = ds.with_format("arrow")[:]
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
# No episode boundaries to align row groups to — keep a single write.
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def to_parquet_one_row_group_per_episode(df: pandas.DataFrame, path: Path) -> None:
|
||||
"""Write a (non-image) DataFrame to parquet with one row group per episode."""
|
||||
table = pa.Table.from_pandas(df, preserve_index=False)
|
||||
if "episode_index" in table.column_names:
|
||||
write_table_one_row_group_per_episode(table, path)
|
||||
else:
|
||||
pq.write_table(table, str(path))
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
|
||||
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
# Load actual data
|
||||
if force_cache_sync or not self.reader.try_load():
|
||||
@@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._ensure_reader().set_image_transforms(image_transforms)
|
||||
self.image_transforms = image_transforms
|
||||
if self.reader is not None:
|
||||
self.reader._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self.set_image_transforms(None)
|
||||
if self.reader is not None:
|
||||
self.reader.set_image_transforms(None)
|
||||
self.image_transforms = None
|
||||
|
||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||
|
||||
|
||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if "camera_obs" in observations:
|
||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||
|
||||
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||
# so env plugins can expose extra observation keys via get_env_processors().
|
||||
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||
for key, value in observations.items():
|
||||
if key in _handled:
|
||||
continue
|
||||
target = f"{OBS_STR}.{key}"
|
||||
if target in return_observations:
|
||||
continue
|
||||
if isinstance(value, np.ndarray):
|
||||
val = torch.from_numpy(value).float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
elif isinstance(value, Tensor):
|
||||
val = value.float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -252,6 +252,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
pretrained_revision: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -309,6 +310,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -318,6 +320,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -557,6 +560,7 @@ def make_policy(
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
|
||||
@@ -124,6 +124,7 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
@@ -345,6 +345,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
|
||||
|
||||
This function uses `importlib.metadata` to find packages installed in the environment
|
||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
||||
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
prefixes = (
|
||||
"lerobot_robot_",
|
||||
"lerobot_camera_",
|
||||
"lerobot_teleoperator_",
|
||||
"lerobot_policy_",
|
||||
"lerobot_env_",
|
||||
)
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow.parquet as pq # noqa: E402
|
||||
|
||||
from lerobot.annotations.steerable_pipeline.reader import iter_episodes # noqa: E402
|
||||
@@ -344,6 +345,78 @@ def test_annotation_metadata_sync_allows_non_streaming_load(
|
||||
assert len(dataset) == 24
|
||||
|
||||
|
||||
def _build_packed_dataset(root: Path, episode_lengths: list[int], *, fps: int = 10) -> Path:
|
||||
"""Pack several episodes into a single shard (vs build_annotation_dataset's one-per-file),
|
||||
so the writer's rewrite must re-emit one row group per episode instead of collapsing them."""
|
||||
from lerobot.datasets.io_utils import write_tasks
|
||||
from lerobot.utils.io_utils import write_json
|
||||
|
||||
data_dir = root / "data" / "chunk-000"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
episode_index, frame_index, timestamp, task_index, subtask_index = [], [], [], [], []
|
||||
for ep, length in enumerate(episode_lengths):
|
||||
episode_index += [ep] * length
|
||||
frame_index += list(range(length))
|
||||
timestamp += [round(i / fps, 6) for i in range(length)]
|
||||
task_index += [0] * length
|
||||
subtask_index += [0] * length # legacy column the writer must drop
|
||||
pd.DataFrame(
|
||||
{
|
||||
"episode_index": episode_index,
|
||||
"frame_index": frame_index,
|
||||
"timestamp": timestamp,
|
||||
"task_index": task_index,
|
||||
"subtask_index": subtask_index,
|
||||
}
|
||||
).to_parquet(data_dir / "file-000.parquet", index=False)
|
||||
|
||||
tasks_df = pd.DataFrame({"task_index": [0]}, index=pd.Index(["do the thing"], name="task"))
|
||||
write_tasks(tasks_df, root)
|
||||
write_json(
|
||||
{"codebase_version": "v3.1", "fps": fps, "features": {}, "total_episodes": len(episode_lengths)},
|
||||
root / "meta" / "info.json",
|
||||
)
|
||||
return root
|
||||
|
||||
|
||||
def test_writer_one_row_group_per_episode(tmp_path: Path) -> None:
|
||||
"""Rewriting a packed shard must keep one row group per episode, not collapse
|
||||
every episode into a single giant row group."""
|
||||
episode_lengths = [4, 6, 5] # unequal lengths, all in one shard
|
||||
root = _build_packed_dataset(tmp_path / "ds", episode_lengths)
|
||||
shard = root / "data" / "chunk-000" / "file-000.parquet"
|
||||
assert pq.ParquetFile(shard).metadata.num_row_groups == 1, "fixture should start collapsed"
|
||||
|
||||
staging_dir = tmp_path / "stage"
|
||||
for ep in range(len(episode_lengths)):
|
||||
_stage_episode(
|
||||
staging_dir,
|
||||
ep,
|
||||
plan=[
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"subtask for ep {ep}",
|
||||
"style": "subtask",
|
||||
"timestamp": 0.0,
|
||||
"tool_calls": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
records = list(iter_episodes(root))
|
||||
LanguageColumnsWriter().write_all(records, staging_dir, root)
|
||||
|
||||
# One row group per episode, with row counts matching the episode lengths.
|
||||
md = pq.ParquetFile(shard).metadata
|
||||
assert md.num_row_groups == len(episode_lengths)
|
||||
assert [md.row_group(i).num_rows for i in range(md.num_row_groups)] == episode_lengths
|
||||
# Language columns are still present after the per-episode rewrite.
|
||||
table = pq.read_table(shard)
|
||||
assert "language_persistent" in table.column_names
|
||||
assert "language_events" in table.column_names
|
||||
|
||||
|
||||
def test_speech_atom_shape_matches_plan_spec() -> None:
|
||||
atom = speech_atom(2.5, "I'm cleaning up!")
|
||||
assert atom["role"] == "assistant"
|
||||
|
||||
@@ -32,6 +32,26 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def assert_data_shards_one_row_group_per_episode(root):
|
||||
"""Every aggregated DATA shard must have exactly one parquet row group per episode."""
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
shards = sorted((root / "data").rglob("*.parquet"))
|
||||
assert shards, f"no data shards found under {root}/data"
|
||||
n_episodes = 0
|
||||
for shard in shards:
|
||||
pf = pq.ParquetFile(shard)
|
||||
episodes = pf.read(columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
assert pf.metadata.num_row_groups == len(set(episodes)), shard
|
||||
for i in range(pf.metadata.num_row_groups):
|
||||
rg_episodes = set(
|
||||
pf.read_row_group(i, columns=["episode_index"]).column("episode_index").to_pylist()
|
||||
)
|
||||
assert len(rg_episodes) == 1, f"{shard} row group {i} spans episodes {rg_episodes}"
|
||||
n_episodes += len(set(episodes))
|
||||
return n_episodes
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
|
||||
"""Test that total number of episodes and frames are correctly aggregated."""
|
||||
assert aggr_ds.num_episodes == expected_episodes, (
|
||||
@@ -566,6 +586,41 @@ def assert_image_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_videos", [True, False], ids=["video", "image"])
|
||||
def test_aggregate_one_row_group_per_episode(tmp_path, lerobot_dataset_factory, use_videos):
|
||||
"""Aggregated DATA shards keep one row group per episode (not one collapsed group).
|
||||
|
||||
Covers both the non-image (``df.to_parquet``) and image
|
||||
(``to_parquet_with_hf_images``) write branches, including the merge-into-
|
||||
existing-file branch via a low file-size threshold that forces packing.
|
||||
"""
|
||||
ds_0 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_0",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_0",
|
||||
total_episodes=3,
|
||||
total_frames=60,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
ds_1 = lerobot_dataset_factory(
|
||||
root=tmp_path / "rg_1",
|
||||
repo_id=f"{DUMMY_REPO_ID}_rg_1",
|
||||
total_episodes=4,
|
||||
total_frames=80,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
|
||||
aggr_root = tmp_path / "rg_aggr"
|
||||
aggregate_datasets(
|
||||
repo_ids=[ds_0.repo_id, ds_1.repo_id],
|
||||
roots=[ds_0.root, ds_1.root],
|
||||
aggr_repo_id=f"{DUMMY_REPO_ID}_rg_aggr",
|
||||
aggr_root=aggr_root,
|
||||
)
|
||||
|
||||
n_episodes = assert_data_shards_one_row_group_per_episode(aggr_root)
|
||||
assert n_episodes == ds_0.num_episodes + ds_1.num_episodes
|
||||
|
||||
|
||||
def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory):
|
||||
"""Test aggregation of image-based datasets preserves HuggingFace Image schema.
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
||||
)
|
||||
dataset_copy = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
||||
)
|
||||
|
||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
||||
|
||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
Reference in New Issue
Block a user