diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 21f600ab1..809789ce2 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -5,9 +5,9 @@ from pathlib import Path import pandas as pd import tqdm -from lerobot.common.datasets.compute_stats import aggregate_stats -from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.common.datasets.utils import ( +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, @@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import ( concat_video_files, get_parquet_file_size_in_mb, get_video_size_in_mb, + to_parquet_with_hf_images, update_chunk_file_indices, write_info, write_stats, @@ -98,11 +99,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] ] ) fps, robot_type, features = validate_all_metadata(all_metadata) -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py - video_keys = [k for k, v in features.items() if v["dtype"] == "video"] -======= video_keys = [key for key in features if features[key]["dtype"] == "video"] ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py # Initialize output dataset metadata dst_meta = LeRobotDatasetMetadata.create( @@ -125,16 +122,14 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys } + dst_meta.episodes = {} + # Process each dataset for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"): videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx) data_idx = aggregate_data(src_meta, dst_meta, data_idx) -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py - meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys) -======= meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_frames"] += src_meta.total_frames @@ -210,15 +205,9 @@ def aggregate_videos(src_meta, dst_meta, videos_idx): file_idx, ) -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py - # Update the video index tracking - video_idx["chunk_idx"] = chunk_idx - video_idx["file_idx"] = file_idx -======= # Update the videos_idx with the final chunk and file indices for this key videos_idx[key]["chunk"] = chunk_idx videos_idx[key]["file"] = file_idx ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py return videos_idx @@ -237,22 +226,15 @@ def aggregate_data(src_meta, dst_meta, data_idx): df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) - dst_path = aggr_root / DEFAULT_DATA_PATH.format( - chunk_index=data_idx["chunk"], file_index=data_idx["file"] - ) - data_idx = write_parquet_safely( + data_idx = append_or_create_parquet_file( df, src_path, - dst_path, data_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE, DEFAULT_DATA_PATH, -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py -======= contains_images=len(dst_meta.image_keys) > 0, aggr_root=dst_meta.root, ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py ) return data_idx @@ -282,17 +264,9 @@ def aggregate_metadata(src_meta, 
dst_meta, meta_idx, data_idx, videos_idx): for k in videos_idx: videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py - dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format( - chunk_index=meta_idx["chunk"], file_index=meta_idx["file"] - ) - write_parquet_safely( -======= meta_idx = append_or_create_parquet_file( ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py df, src_path, - dst_path, meta_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE, @@ -304,19 +278,15 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): return meta_idx -def write_parquet_safely( +def append_or_create_parquet_file( df: pd.DataFrame, src_path: Path, - dst_path: Path, idx: dict[str, int], max_mb: float, chunk_size: int, default_path: str, -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py -======= contains_images: bool = False, aggr_root: Path = None, ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py ): """ Safely appends or creates a Parquet file at dst_path based on size constraints. @@ -324,7 +294,6 @@ def write_parquet_safely( Parameters: df (pd.DataFrame): Data to write. src_path (Path): Path to source file (used to get size). - dst_path (Path): Target path for writing. idx (dict): Dictionary containing 'chunk' and 'file' indices. max_mb (float): Maximum allowed file size in MB. chunk_size (int): Maximum number of files per chunk. @@ -333,11 +302,8 @@ def write_parquet_safely( Returns: dict: Updated index dictionary. """ -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py -======= # Initial destination path - use the correct default_path parameter dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py # If destination file doesn't exist, just write the new one if not dst_path.exists(): @@ -357,14 +323,6 @@ def write_parquet_safely( idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) new_path.parent.mkdir(parents=True, exist_ok=True) -<<<<<<< HEAD:src/lerobot/datasets/aggregate.py - df.to_parquet(new_path) - else: - # Append to existing file - existing_df = pd.read_parquet(dst_path) - combined_df = pd.concat([existing_df, df], ignore_index=True) - combined_df.to_parquet(dst_path) -======= final_df = df target_path = new_path else: @@ -377,7 +335,6 @@ def write_parquet_safely( to_parquet_with_hf_images(final_df, target_path) else: final_df.to_parquet(target_path) ->>>>>>> aa07b858 (fix: debug aggregation code):lerobot/common/datasets/aggregate.py return idx diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index e69de29bb..b4777582d 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -0,0 +1,1330 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import logging +import shutil +import tempfile +from pathlib import Path +from typing import Callable + +import datasets +import numpy as np +import packaging.version +import pandas as pd +import PIL.Image +import torch +import torch.utils +from datasets import Dataset +from huggingface_hub import HfApi, snapshot_download +from huggingface_hub.constants import REPOCARD_NAME +from huggingface_hub.errors import RevisionNotFoundError + +from lerobot.constants import HF_LEROBOT_HOME +from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats +from lerobot.datasets.image_writer import AsyncImageWriter, write_image +from lerobot.datasets.utils import ( + DEFAULT_EPISODES_PATH, + DEFAULT_FEATURES, + DEFAULT_IMAGE_PATH, + INFO_PATH, + _validate_feature_names, + check_delta_timestamps, + check_version_compatibility, + concat_video_files, + create_empty_dataset_info, + create_lerobot_dataset_card, + embed_images, + flatten_dict, + get_delta_indices, + get_hf_dataset_size_in_mb, + get_hf_features_from_features, + get_parquet_file_size_in_mb, + get_parquet_num_frames, + get_safe_version, + get_video_duration_in_s, + get_video_size_in_mb, + hf_transform_to_torch, + is_valid_version, + load_episodes, + load_info, + load_nested_dataset, + load_stats, + load_tasks, + to_parquet_with_hf_images, + update_chunk_file_indices, + validate_episode_buffer, + validate_frame, + write_info, + write_json, + write_stats, + write_tasks, +) +from lerobot.datasets.video_utils import ( + VideoFrame, + decode_video_frames, + encode_video_frames, + get_safe_default_codec, + get_video_info, +) + +CODEBASE_VERSION = "v3.0" + + +class LeRobotDatasetMetadata: + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + revision: str | None = None, + force_cache_sync: bool = False, + ): + self.repo_id = repo_id + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + # TODO(rcadene): instead of downloading all episodes metadata files, + # download only the ones associated to the requested episodes. This would + # require adding `episodes: list[int]` as argument. 
+ self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def load_metadata(self): + self.info = load_info(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks = load_tasks(self.root) + self.episodes = load_episodes(self.root) + self.stats = load_stats(self.root) + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + @property + def _version(self) -> packaging.version.Version: + """Codebase version used to create this dataset.""" + return packaging.version.parse(self.info["codebase_version"]) + + def get_data_file_path(self, ep_index: int) -> Path: + ep = self.episodes[ep_index] + chunk_idx = ep["data/chunk_index"] + file_idx = ep["data/file_index"] + fpath = self.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: + ep = self.episodes[ep_index] + chunk_idx = ep[f"videos/{vid_key}/chunk_index"] + file_idx = ep[f"videos/{vid_key}/file_index"] + fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) + return Path(fpath) + + @property + def data_path(self) -> str: + """Formattable string for the parquet files.""" + return self.info["data_path"] + + @property + def video_path(self) -> str | None: + """Formattable string for the video files.""" + return self.info["video_path"] + + @property + def robot_type(self) -> str | None: + """Robot type used in recording this dataset.""" + return self.info["robot_type"] + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.info["fps"] + + @property + def features(self) -> dict[str, dict]: + """All features contained in the dataset.""" + return self.info["features"] + + @property + def image_keys(self) -> list[str]: + """Keys to access visual modalities stored as images.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "image"] + + @property + def video_keys(self) -> list[str]: + """Keys to access visual modalities stored as videos.""" + return [key for key, ft in self.features.items() if ft["dtype"] == "video"] + + @property + def camera_keys(self) -> list[str]: + """Keys to access visual modalities (regardless of their storage method).""" + return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + + @property + def names(self) -> dict[str, list | dict]: + """Names of the various dimensions of vector modalities.""" + return {key: ft["names"] for key, ft in self.features.items()} + + @property + def shapes(self) -> dict: + """Shapes for the different features.""" + return {key: tuple(ft["shape"]) for key, ft in self.features.items()} + + @property + def total_episodes(self) -> int: + """Total number of episodes available.""" + return self.info["total_episodes"] + + @property + def total_frames(self) -> int: + """Total number of frames saved in this dataset.""" + return self.info["total_frames"] + + @property + def total_tasks(self) -> int: + """Total number of different tasks performed in this dataset.""" + return self.info["total_tasks"] + + @property + def chunks_size(self) -> int: + """Max number of files per chunk.""" + return self.info["chunks_size"] + + @property + def 
data_files_size_in_mb(self) -> int: + """Max size of data file in megabytes.""" + return self.info["data_files_size_in_mb"] + + @property + def video_files_size_in_mb(self) -> int: + """Max size of video file in megabytes.""" + return self.info["video_files_size_in_mb"] + + def get_task_index(self, task: str) -> int | None: + """ + Given a task in natural language, returns its task_index if the task already exists in the dataset, + otherwise returns None. + """ + if task in self.tasks.index: + return int(self.tasks.loc[task].task_index) + else: + return None + + def save_episode_tasks(self, tasks: list[str]): + if len(set(tasks)) != len(tasks): + raise ValueError(f"Tasks are not unique: {tasks}") + + if self.tasks is None: + new_tasks = tasks + task_indices = range(len(tasks)) + self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks) + else: + new_tasks = [task for task in tasks if task not in self.tasks.index] + new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks)) + for task_idx, task in zip(new_task_indices, new_tasks, strict=False): + self.tasks.loc[task] = task_idx + + if len(new_tasks) > 0: + # Update on disk + write_tasks(self.tasks, self.root) + + def _save_episode_metadata(self, episode_dict: dict) -> None: + """Save episode metadata to a parquet file and update the Hugging Face dataset of episodes metadata. + + This function processes episode metadata from a dictionary, converts it into a Hugging Face dataset, + and saves it as a parquet file. It handles both the creation of new parquet files and the + updating of existing ones based on size constraints. After saving the metadata, it reloads + the Hugging Face dataset to ensure it is up-to-date. + + Notes: We need to update both the parquet files and the HF dataset: + - `pandas` loads the parquet file into RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from the pyarrow cache.
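+ + To make the note concrete, a minimal sketch of the two access paths (illustrative, not part of this method): + ``` + >>> df = pd.read_parquet(path) # pandas: decompresses the parquet file into RAM + >>> ds = Dataset.from_parquet(str(path)) # datasets: converts once to a pyarrow cache on disk, then memory-maps it + ```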
+ """ + # Convert buffer into HF Dataset + episode_dict = {key: [value] for key, value in episode_dict.items()} + ep_dataset = Dataset.from_dict(episode_dict) + ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) + df = pd.DataFrame(ep_dataset) + num_frames = episode_dict["length"][0] + + if self.episodes is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + df["meta/episodes/chunk_index"] = [chunk_idx] + df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [0] + df["dataset_to_index"] = [num_frames] + else: + # Retrieve information from the latest parquet file + latest_ep = self.episodes[-1] + chunk_idx = latest_ep["meta/episodes/chunk_index"] + file_idx = latest_ep["meta/episodes/file_index"] + + latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + + if latest_size_in_mb + ep_size_in_mb >= self.data_files_size_in_mb: + # Size limit is reached, prepare new parquet file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.chunks_size) + + # Update the existing pandas dataframe with new row + df["meta/episodes/chunk_index"] = [chunk_idx] + df["meta/episodes/file_index"] = [file_idx] + df["dataset_from_index"] = [latest_ep["dataset_to_index"]] + df["dataset_to_index"] = [latest_ep["dataset_to_index"] + num_frames] + + if latest_size_in_mb + ep_size_in_mb < self.data_files_size_in_mb: + # Size limit wasnt reached, concatenate latest dataframe with new one + latest_df = pd.read_parquet(latest_path) + df = pd.concat([latest_df, df], ignore_index=True) + + # Write the resulting dataframe from RAM to disk + path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(path, index=False) + + # Update the Hugging Face dataset by reloading it. + # This process should be fast because only the latest Parquet file has been modified. + # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache. + self.episodes = load_episodes(self.root) + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + episode_metadata: dict, + ) -> None: + episode_dict = { + "episode_index": episode_index, + "tasks": episode_tasks, + "length": episode_length, + } + episode_dict.update(episode_metadata) + episode_dict.update(flatten_dict({"stats": episode_stats})) + self._save_episode_metadata(episode_dict) + + # Update info + self.info["total_episodes"] += 1 + self.info["total_frames"] += episode_length + self.info["total_tasks"] = len(self.tasks) + self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} + if len(self.video_keys) > 0: + self.update_video_info() + write_info(self.info, self.root) + + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats + write_stats(self.stats, self.root) + + def update_video_info(self) -> None: + """ + Warning: this function writes info from first episode videos, implicitly assuming that all videos have + been encoded the same way. Also, this means it assumes the first episode exists. 
+ """ + for key in self.video_keys: + if not self.features[key].get("info", None): + video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) + self.info["features"][key]["info"] = get_video_info(video_path) + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Total episodes: '{self.total_episodes}',\n" + f" Total frames: '{self.total_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + robot_type: str | None = None, + root: str | Path | None = None, + use_videos: bool = True, + ) -> "LeRobotDatasetMetadata": + """Creates metadata for a LeRobotDataset.""" + obj = cls.__new__(cls) + obj.repo_id = repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id + + obj.root.mkdir(parents=True, exist_ok=False) + + features = {**features, **DEFAULT_FEATURES} + _validate_feature_names(features) + + obj.tasks = None + obj.episodes = None + obj.stats = None + obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + if len(obj.video_keys) > 0 and not use_videos: + raise ValueError() + write_json(obj.info, obj.root / INFO_PATH) + obj.revision = None + return obj + + +class LeRobotDataset(torch.utils.data.Dataset): + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + episodes: list[int] | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + download_videos: bool = True, + video_backend: str | None = None, + ): + """ + 2 modes are available for instantiating this class, depending on 2 different use cases: + + 1. Your dataset already exists: + - On your local disk in the 'root' folder. This is typically the case when you recorded your + dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class + with 'root' will load your dataset directly from disk. This can happen while you're offline (no + internet connection). + + - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on + your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download + the dataset from that address and load it, pending your dataset is compliant with + codebase_version v2.0. If your dataset has been created before this new format, you will be + prompted to convert it using our conversion script from v1.6 to v2.0, which you can find at + lerobot/datasets/v2/convert_dataset_v1_to_v2.py. + + + 2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty + LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or port an + existing dataset to the LeRobotDataset format. + + + In terms of files, LeRobotDataset encapsulates 3 main things: + - metadata: + - info contains various information about the dataset like shapes, keys, fps etc. + - stats stores the dataset statistics of the different modalities for normalization + - tasks contains the prompts for each task of the dataset, which can be used for + task-conditioned training. + - hf_dataset (from datasets.Dataset), which will read any values from parquet files. + - videos (optional) from which frames are loaded to be synchronous with data from parquet files. 
+ + A typical LeRobotDataset looks like this from its root path: + . + ├── data + │ ├── chunk-000 + │ │ ├── episode_000000.parquet + │ │ ├── episode_000001.parquet + │ │ ├── episode_000002.parquet + │ │ └── ... + │ ├── chunk-001 + │ │ ├── episode_001000.parquet + │ │ ├── episode_001001.parquet + │ │ ├── episode_001002.parquet + │ │ └── ... + │ └── ... + ├── meta + │ ├── episodes.jsonl + │ ├── info.json + │ ├── stats.json + │ └── tasks.jsonl + └── videos + ├── chunk-000 + │ ├── observation.images.laptop + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + │ ├── observation.images.phone + │ │ ├── episode_000000.mp4 + │ │ ├── episode_000001.mp4 + │ │ ├── episode_000002.mp4 + │ │ └── ... + ├── chunk-001 + └── ... + + Note that this file-based structure is designed to be as versatile as possible. The files are split by + episodes, which allows more granular control over which episodes one wants to use and download. The + structure of the dataset is entirely described in the info.json file, which can be easily downloaded + or viewed directly on the hub before downloading any actual data. The file types used are very + simple and do not need complex tools to be read: only .parquet, .json and .mp4 files are used (and .md + for the README). + + Args: + repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset + will be stored under root/repo_id. + root (Path | None, optional): Local directory to use for downloading/writing files. You can also + set the LEROBOT_HOME environment variable to point to a different location. Defaults to + '~/.cache/huggingface/lerobot'. + episodes (list[int] | None, optional): If specified, this will only load episodes specified by + their episode_index in this list. Defaults to None. + image_transforms (Callable | None, optional): You can pass standard v2 image transforms from + torchvision.transforms.v2 here which will be applied to visual modalities (whether they come + from videos or images). Defaults to None. + delta_timestamps (dict[str, list[float]] | None, optional): Dictionary mapping a feature key to a list + of time offsets in seconds (multiples of 1/fps) at which extra frames are queried relative to the + current frame. The queried frames are stacked along a new leading dimension, and frames falling + outside the episode are padded and flagged in '{key}_is_pad'. Defaults to None. + tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in + sync with the fps value. It is used at the init of the dataset to make sure that each + timestamp is separated from the next by 1/fps +/- tolerance_s. This also applies to frames + decoded from video files. It is also used to check that `delta_timestamps` (when provided) are + multiples of 1/fps. Defaults to 1e-4. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. Defaults to current codebase version tag. + force_cache_sync (bool, optional): Flag to sync and refresh local files first. If False and files + are already present in the local cache, this will be faster. However, files loaded might not + be in sync with the version on the hub, especially if you specified 'revision'. Defaults to + False. + download_videos (bool, optional): Flag to download the videos. Note that when set to True but the + video files are already present on local disk, they won't be downloaded again. Defaults to + True. + video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available on the platform; otherwise, defaults to 'pyav'. + You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
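+ + Example (a minimal usage sketch; 'lerobot/pusht' stands in for any compliant repo_id): + ``` + >>> dataset = LeRobotDataset("lerobot/pusht", delta_timestamps={"action": [0.0, 0.1, 0.2]}) + >>> item = dataset[0] # "action" holds 3 stacked frames; "action_is_pad" flags padded ones + ```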
+ """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.delta_indices = None + + # Unused attributes + self.image_writer = None + self.episode_buffer = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + + # Load actual data + try: + if force_cache_sync: + raise FileNotFoundError + self.hf_dataset = self.load_hf_dataset() + except (AssertionError, FileNotFoundError, NotADirectoryError): + self.revision = get_safe_version(self.repo_id, self.revision) + self.download(download_videos) + self.hf_dataset = self.load_hf_dataset() + + # Setup delta_indices + if self.delta_timestamps is not None: + check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + + def push_to_hub( + self, + branch: str | None = None, + tags: list | None = None, + license: str | None = "apache-2.0", + tag_version: bool = True, + push_videos: bool = True, + private: bool = False, + allow_patterns: list[str] | str | None = None, + upload_large_folder: bool = False, + **card_kwargs, + ) -> None: + ignore_patterns = ["images/"] + if not push_videos: + ignore_patterns.append("videos/") + + hub_api = HfApi() + hub_api.create_repo( + repo_id=self.repo_id, + private=private, + repo_type="dataset", + exist_ok=True, + ) + if branch: + hub_api.create_branch( + repo_id=self.repo_id, + branch=branch, + revision=self.revision, + repo_type="dataset", + exist_ok=True, + ) + + upload_kwargs = { + "repo_id": self.repo_id, + "folder_path": self.root, + "repo_type": "dataset", + "revision": branch, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + } + if upload_large_folder: + hub_api.upload_large_folder(**upload_kwargs) + else: + hub_api.upload_folder(**upload_kwargs) + + if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) + + if tag_version: + with contextlib.suppress(RevisionNotFoundError): + hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + def pull_from_repo( + self, + allow_patterns: list[str] | str | None = None, + ignore_patterns: list[str] | str | None = None, + ) -> None: + snapshot_download( + self.repo_id, + repo_type="dataset", + revision=self.revision, + local_dir=self.root, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + def download(self, download_videos: bool = True) -> None: + """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this + will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole + dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present + in 'local_dir', they won't be downloaded again. 
+ """ + # TODO(rcadene, aliberts): implement faster transfer + # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads + ignore_patterns = None if download_videos else "videos/" + files = None + if self.episodes is not None: + files = self.get_episodes_file_paths() + self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) + + def get_episodes_file_paths(self) -> list[Path]: + episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self.meta.video_keys) > 0: + video_files = [ + str(self.meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self.meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + # episodes are stored in the same files, so we return unique paths only + fpaths = list(set(fpaths)) + return fpaths + + def load_hf_dataset(self) -> datasets.Dataset: + """hf_dataset contains all the observations, states, actions, rewards, etc.""" + hf_dataset = load_nested_dataset(self.root / "data") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + def create_hf_dataset(self) -> datasets.Dataset: + features = get_hf_features_from_features(self.features) + ft_dict = {col: [] for col in features} + hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") + hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + + @property + def fps(self) -> int: + """Frames per second used during data collection.""" + return self.meta.fps + + @property + def num_frames(self) -> int: + """Number of frames in selected episodes.""" + return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames + + @property + def num_episodes(self) -> int: + """Number of episodes selected.""" + return len(self.episodes) if self.episodes is not None else self.meta.total_episodes + + @property + def features(self) -> dict[str, dict]: + return self.meta.features + + @property + def hf_features(self) -> datasets.Features: + """Features of the hf_dataset.""" + if self.hf_dataset is not None: + return self.hf_dataset.features + else: + return get_hf_features_from_features(self.features) + + def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + ep = self.meta.episodes[ep_idx] + ep_start = ep["dataset_from_index"] + ep_end = ep["dataset_to_index"] + query_indices = { + key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] + for key, delta_idx in self.delta_indices.items() + } + padding = { # Pad values outside of current episode range + f"{key}_is_pad": torch.BoolTensor( + [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx] + ) + for key, delta_idx in self.delta_indices.items() + } + return query_indices, padding + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.video_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset[query_indices[key]]["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: + return { + key: torch.stack(self.hf_dataset[q_idx][key]) + for key, q_idx in query_indices.items() + if key not in self.meta.video_keys + } 
+ + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + ep = self.meta.episodes[ep_idx] + item = {} + for vid_key, query_ts in query_timestamps.items(): + # Episodes are stored sequentially on a single mp4 to reduce the number of files. + # Thus we load the start timestamp of the episode on this mp4 and, + # shift the query timestamp accordingly. + from_timestamp = ep[f"videos/{vid_key}/from_timestamp"] + shifted_query_ts = [from_timestamp + ts for ts in query_ts] + + video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) + frames = decode_video_frames(video_path, shifted_query_ts, self.tolerance_s, self.video_backend) + item[vid_key] = frames.squeeze(0) + + return item + + def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict: + for key, val in padding.items(): + item[key] = torch.BoolTensor(val) + return item + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx) -> dict: + item = self.hf_dataset[idx] + ep_idx = item["episode_index"].item() + + query_indices = None + if self.delta_indices is not None: + query_indices, padding = self._get_query_indices(idx, ep_idx) + query_result = self._query_hf_dataset(query_indices) + item = {**item, **padding} + for key, val in query_result.items(): + item[key] = val + + if len(self.meta.video_keys) > 0: + current_ts = item["timestamp"].item() + query_timestamps = self._get_query_timestamps(current_ts, query_indices) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + if self.image_transforms is not None: + image_keys = self.meta.camera_keys + for cam in image_keys: + item[cam] = self.image_transforms(item[cam]) + + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self.meta.tasks.iloc[task_idx].name + return item + + def __repr__(self): + feature_keys = list(self.features) + return ( + f"{self.__class__.__name__}({{\n" + f" Repository ID: '{self.repo_id}',\n" + f" Number of selected episodes: '{self.num_episodes}',\n" + f" Number of selected samples: '{self.num_frames}',\n" + f" Features: '{feature_keys}',\n" + "})',\n" + ) + + def create_episode_buffer(self, episode_index: int | None = None) -> dict: + current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index + ep_buffer = {} + # size and task are special cases that are not in self.features + ep_buffer["size"] = 0 + ep_buffer["task"] = [] + for key in self.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer + + def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + fpath = DEFAULT_IMAGE_PATH.format( + image_key=image_key, episode_index=episode_index, frame_index=frame_index + ) + return self.root / fpath + + def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path: + return self._get_image_file_path(episode_index, image_key, frame_index=0).parent + + def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + if self.image_writer is None: + if isinstance(image, torch.Tensor): + 
image = image.cpu().numpy() + write_image(image, fpath) + else: + self.image_writer.save_image(image=image, fpath=fpath) + + def add_frame(self, frame: dict) -> None: + """ + This function only adds the frame to the episode_buffer. Apart from images — which are written in a + temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method + then needs to be called. + """ + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self.features) + + if self.episode_buffer is None: + self.episode_buffer = self.create_episode_buffer() + + # Automatically add frame_index and timestamp to episode buffer + frame_index = self.episode_buffer["size"] + timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + self.episode_buffer["frame_index"].append(frame_index) + self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing + + # Add frame features to episode_buffer + for key in frame: + if key not in self.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." + ) + + if self.features[key]["dtype"] in ["image", "video"]: + img_path = self._get_image_file_path( + episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + ) + if frame_index == 0: + img_path.parent.mkdir(parents=True, exist_ok=True) + self._save_image(frame[key], img_path) + self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) + + self.episode_buffer["size"] += 1 + + def save_episode(self, episode_data: dict | None = None) -> None: + """ + This will save to disk the current episode in self.episode_buffer. + + Args: + episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will + save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to + None. 
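+ + Example (a minimal recording sketch; the repo_id, features and frame values are illustrative): + ``` + >>> dataset = LeRobotDataset.create("user/demo", fps=30, features=features) + >>> for t in range(100): + ... dataset.add_frame({"action": action[t], "observation.state": state[t], "task": "pick cube"}) + >>> dataset.save_episode() + ```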
+ """ + if not episode_data: + episode_buffer = self.episode_buffer + + validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) + + # size and task are special cases that won't be added to hf_dataset + episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) + episode_index = episode_buffer["episode_index"] + + episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) + + # Update tasks and task indices with new tasks if any + self.meta.save_episode_tasks(episode_tasks) + + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) + + for key, ft in self.features.items(): + # index, episode_index, task_index are already processed above, and image and video + # are processed separately by storing image path and frame info as meta data + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: + continue + episode_buffer[key] = np.stack(episode_buffer[key]) + + # Wait for image writer to end, so that episode stats over images can be computed + self._wait_image_writer() + ep_stats = compute_episode_stats(episode_buffer, self.features) + + ep_metadata = self._save_episode_data(episode_buffer) + for video_key in self.meta.video_keys: + ep_metadata.update(self._save_episode_video(video_key, episode_index)) + + # `meta.save_episode` need to be executed after encoding the videos + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) + + # TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index + # ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) + # ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} + # check_timestamps_sync( + # episode_buffer["timestamp"], + # episode_buffer["episode_index"], + # ep_data_index_np, + # self.fps, + # self.tolerance_s, + # ) + + # TODO(rcadene): images are also deleted in clear_episode_buffer + # delete images + img_dir = self.root / "images" + if img_dir.is_dir(): + shutil.rmtree(self.root / "images") + + if not episode_data: + # Reset episode buffer + self.episode_buffer = self.create_episode_buffer() + + def _save_episode_data(self, episode_buffer: dict) -> dict: + """Save episode data to a parquet file and update the Hugging Face dataset of frames data. + + This function processes episodes data from a buffer, converts it into a Hugging Face dataset, + and saves it as a parquet file. It handles both the creation of new parquet files and the + updating of existing ones based on size constraints. After saving the data, it reloads + the Hugging Face dataset to ensure it is up-to-date. + + Notes: We both need to update parquet files and HF dataset: + - `pandas` loads parquet file in RAM + - `datasets` relies on a memory mapping from pyarrow (no RAM). It either converts parquet files to a pyarrow cache on disk, + or loads directly from pyarrow cache. 
+ """ + # Convert buffer into HF Dataset + ep_dict = {key: episode_buffer[key] for key in self.hf_features} + ep_dataset = datasets.Dataset.from_dict(ep_dict, features=self.hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + ep_size_in_mb = get_hf_dataset_size_in_mb(ep_dataset) + ep_num_frames = len(ep_dataset) + df = pd.DataFrame(ep_dataset) + + if self.meta.episodes is None: + # Initialize indices and frame count for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + latest_num_frames = 0 + else: + # Retrieve information from the latest parquet file + latest_ep = self.meta.episodes[-1] + chunk_idx = latest_ep["data/chunk_index"] + file_idx = latest_ep["data/file_index"] + + latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + latest_size_in_mb = get_parquet_file_size_in_mb(latest_path) + latest_num_frames = get_parquet_num_frames(latest_path) + + # Determine if a new parquet file is needed + if latest_size_in_mb + ep_size_in_mb >= self.meta.data_files_size_in_mb: + # Size limit is reached, prepare new parquet file + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) + latest_num_frames = 0 + else: + # Update the existing parquet file with new rows + latest_df = pd.read_parquet(latest_path) + df = pd.concat([latest_df, df], ignore_index=True) + + # Write the resulting dataframe from RAM to disk + path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + if len(self.meta.image_keys) > 0: + to_parquet_with_hf_images(df, path) + else: + df.to_parquet(path) + + # Update the Hugging Face dataset by reloading it. + # This process should be fast because only the latest Parquet file has been modified. + # Therefore, only this file needs to be converted to PyArrow; the rest is loaded from the PyArrow memory-mapped cache. 
+ self.hf_dataset = self.load_hf_dataset() + + metadata = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": latest_num_frames, + "dataset_to_index": latest_num_frames + ep_num_frames, + } + return metadata + + def _save_episode_video(self, video_key: str, episode_index: int): + # Encode episode frames into a temporary video + ep_path = self._encode_temporary_episode_video(video_key, episode_index) + ep_size_in_mb = get_video_size_in_mb(ep_path) + ep_duration_in_s = get_video_duration_in_s(ep_path) + + if self.meta.episodes is None: + # Initialize indices for a new dataset made of the first episode data + chunk_idx, file_idx = 0, 0 + latest_duration_in_s = 0 + new_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + # Retrieve information from the latest video file + latest_ep = self.meta.episodes[-1] + chunk_idx = latest_ep[f"videos/{video_key}/chunk_index"] + file_idx = latest_ep[f"videos/{video_key}/file_index"] + + latest_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + latest_size_in_mb = get_video_size_in_mb(latest_path) + latest_duration_in_s = get_video_duration_in_s(latest_path) + + if latest_size_in_mb + ep_size_in_mb >= self.meta.video_files_size_in_mb: + # Move temporary episode video to a new video file in the dataset + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size) + new_path = self.root / self.meta.video_path.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + new_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(ep_path), str(new_path)) + else: + # Update latest video file + concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx) + + # Remove temporary directory + shutil.rmtree(str(ep_path.parent)) + + metadata = { + "episode_index": episode_index, + f"videos/{video_key}/chunk_index": chunk_idx, + f"videos/{video_key}/file_index": file_idx, + f"videos/{video_key}/from_timestamp": latest_duration_in_s, + f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s, + } + return metadata + + def clear_episode_buffer(self) -> None: + episode_index = self.episode_buffer["episode_index"] + if self.image_writer is not None: + for cam_key in self.meta.camera_keys: + img_dir = self._get_image_file_path( + episode_index=episode_index, image_key=cam_key, frame_index=0 + ).parent + if img_dir.is_dir(): + shutil.rmtree(img_dir) + + # Reset the buffer + self.episode_buffer = self.create_episode_buffer() + + def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: + if isinstance(self.image_writer, AsyncImageWriter): + logging.warning( + "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." + ) + + self.image_writer = AsyncImageWriter( + num_processes=num_processes, + num_threads=num_threads, + ) + + def stop_image_writer(self) -> None: + """ + Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to + remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized. 
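+ + A sketch of the intended call order (illustrative): + ``` + >>> dataset.start_image_writer(num_processes=0, num_threads=4) + >>> # ... record episodes with add_frame / save_episode ... + >>> dataset.stop_image_writer() # call before wrapping the dataset in a num_workers > 0 DataLoader + ```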
+ """ + if self.image_writer is not None: + self.image_writer.stop() + self.image_writer = None + + def _wait_image_writer(self) -> None: + """Wait for asynchronous image writer to finish.""" + if self.image_writer is not None: + self.image_writer.wait_until_done() + + def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict: + """ + Use ffmpeg to convert frames stored as png into mp4 videos. + Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, + since video encoding with ffmpeg is already using multithreading. + """ + temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4" + img_dir = self._get_image_file_dir(episode_index, video_key) + encode_video_frames(img_dir, temp_path, self.fps, overwrite=True) + return temp_path + + @classmethod + def create( + cls, + repo_id: str, + fps: int, + features: dict, + root: str | Path | None = None, + robot_type: str | None = None, + use_videos: bool = True, + tolerance_s: float = 1e-4, + image_writer_processes: int = 0, + image_writer_threads: int = 0, + video_backend: str | None = None, + ) -> "LeRobotDataset": + """Create a LeRobot Dataset from scratch in order to record data.""" + obj = cls.__new__(cls) + obj.meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=fps, + robot_type=robot_type, + features=features, + root=root, + use_videos=use_videos, + ) + obj.repo_id = obj.meta.repo_id + obj.root = obj.meta.root + obj.revision = None + obj.tolerance_s = tolerance_s + obj.image_writer = None + + if image_writer_processes or image_writer_threads: + obj.start_image_writer(image_writer_processes, image_writer_threads) + + # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer + obj.episode_buffer = obj.create_episode_buffer() + + obj.episodes = None + obj.hf_dataset = obj.create_hf_dataset() + obj.image_transforms = None + obj.delta_timestamps = None + obj.delta_indices = None + obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec() + return obj + + +class MultiLeRobotDataset(torch.utils.data.Dataset): + """A dataset consisting of multiple underlying `LeRobotDataset`s. + + The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API + structure of `LeRobotDataset`. + """ + + def __init__( + self, + repo_ids: list[str], + root: str | Path | None = None, + episodes: dict | None = None, + image_transforms: Callable | None = None, + delta_timestamps: dict[list[float]] | None = None, + tolerances_s: dict | None = None, + download_videos: bool = True, + video_backend: str | None = None, + ): + super().__init__() + self.repo_ids = repo_ids + self.root = Path(root) if root else HF_LEROBOT_HOME + self.tolerances_s = tolerances_s if tolerances_s else dict.fromkeys(repo_ids, 0.0001) + # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which + # are handled by this class. + self._datasets = [ + LeRobotDataset( + repo_id, + root=self.root / repo_id, + episodes=episodes[repo_id] if episodes else None, + image_transforms=image_transforms, + delta_timestamps=delta_timestamps, + tolerance_s=self.tolerances_s[repo_id], + download_videos=download_videos, + video_backend=video_backend, + ) + for repo_id in repo_ids + ] + + # Disable any data keys that are not common across all of the datasets. Note: we may relax this + # restriction in future iterations of this class. 
For now, this is necessary at least for being able + # to use PyTorch's default DataLoader collate function. + self.disabled_features = set() + intersection_features = set(self._datasets[0].features) + for ds in self._datasets: + intersection_features.intersection_update(ds.features) + if len(intersection_features) == 0: + raise RuntimeError( + "Multiple datasets were provided but they had no keys common to all of them. " + "The multi-dataset functionality currently only keeps common keys." + ) + for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): + extra_keys = set(ds.features).difference(intersection_features) + logging.warning( + f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " + "other datasets." + ) + self.disabled_features.update(extra_keys) + + self.image_transforms = image_transforms + self.delta_timestamps = delta_timestamps + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. + self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) + + @property + def repo_id_to_index(self): + """Return a mapping from dataset repo_id to a dataset index automatically created by this class. + + This index is incorporated as a data key in the dictionary returned by `__getitem__`. + """ + return {repo_id: i for i, repo_id in enumerate(self.repo_ids)} + + @property + def repo_index_to_id(self): + """Return the inverse mapping of repo_id_to_index.""" + return {v: k for k, v in self.repo_id_to_index.items()} + + @property + def fps(self) -> int: + """Frames per second used during data collection. + + NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info["fps"] + + @property + def video(self) -> bool: + """Returns True if this dataset loads video frames from mp4 files. + + Returns False if it only loads images from png files. + + NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info. + """ + return self._datasets[0].meta.info.get("video", False) + + @property + def features(self) -> datasets.Features: + features = {} + for dataset in self._datasets: + features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + return features + + @property + def camera_keys(self) -> list[str]: + """Keys to access image and video streams from cameras.""" + keys = [] + for key, feats in self.features.items(): + if isinstance(feats, (datasets.Image, VideoFrame)): + keys.append(key) + return keys + + @property + def video_frame_keys(self) -> list[str]: + """Keys to access video frames that require decoding into images. + + Note: It is empty if the dataset contains images only, + or equal to `self.cameras` if the dataset contains videos only, + or can even be a subset of `self.cameras` in the case of a mixed image/video dataset.
+ """ + video_frame_keys = [] + for key, feats in self.features.items(): + if isinstance(feats, VideoFrame): + video_frame_keys.append(key) + return video_frame_keys + + @property + def num_frames(self) -> int: + """Number of samples/frames.""" + return sum(d.num_frames for d in self._datasets) + + @property + def num_episodes(self) -> int: + """Number of episodes.""" + return sum(d.num_episodes for d in self._datasets) + + @property + def tolerance_s(self) -> float: + """Tolerance in seconds used to discard loaded frames when their timestamps + are not close enough from the requested frames. It is only used when `delta_timestamps` + is provided or when loading video frames from mp4 files. + """ + # 1e-4 to account for possible numerical error + return 1 / self.fps - 1e-4 + + def __len__(self): + return self.num_frames + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Index {idx} out of bounds.") + # Determine which dataset to get an item from based on the index. + start_idx = 0 + dataset_idx = 0 + for dataset in self._datasets: + if idx >= start_idx + dataset.num_frames: + start_idx += dataset.num_frames + dataset_idx += 1 + continue + break + else: + raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + item = self._datasets[dataset_idx][idx - start_idx] + item["dataset_index"] = torch.tensor(dataset_idx) + for data_key in self.disabled_features: + if data_key in item: + del item[data_key] + + return item + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Repository IDs: '{self.repo_ids}',\n" + f" Number of Samples: {self.num_frames},\n" + f" Number of Episodes: {self.num_episodes},\n" + f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" + f" Recorded Frames per Second: {self.fps},\n" + f" Camera Keys: {self.camera_keys},\n" + f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n" + f" Transformations: {self.image_transforms},\n" + f")" + ) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index e69de29bb..6223b98a2 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -0,0 +1,954 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import contextlib +import importlib.resources +import json +import logging +import shutil +import subprocess +import tempfile +from collections.abc import Iterator +from pathlib import Path +from pprint import pformat +from types import SimpleNamespace +from typing import Any + +import datasets +import numpy as np +import packaging.version +import pandas +import pandas as pd +import pyarrow.parquet as pq +import torch +from datasets import Dataset, concatenate_datasets +from datasets.table import embed_table_storage +from huggingface_hub import DatasetCard, DatasetCardData, HfApi +from huggingface_hub.errors import RevisionNotFoundError +from PIL import Image as PILImage +from torchvision import transforms + +from lerobot.datasets.backward_compatibility import ( + V21_MESSAGE, + BackwardCompatibilityError, + ForwardCompatibilityError, +) +from lerobot.robots import Robot +from lerobot.utils.utils import is_valid_numpy_dtype_string +from lerobot.configs.types import FeatureType, PolicyFeature + +DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk +DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file +DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file + +INFO_PATH = "meta/info.json" +STATS_PATH = "meta/stats.json" + +EPISODES_DIR = "meta/episodes" +DATA_DIR = "data" +VIDEO_DIR = "videos" + +CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" +DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" +DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" +DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" +DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png" + +DATASET_CARD_TEMPLATE = """ +--- +# Metadata will go there +--- +This dataset was created using [LeRobot](https://github.com/huggingface/lerobot). + +## {} + +""" + +DEFAULT_FEATURES = { + "timestamp": {"dtype": "float32", "shape": (1,), "names": None}, + "frame_index": {"dtype": "int64", "shape": (1,), "names": None}, + "episode_index": {"dtype": "int64", "shape": (1,), "names": None}, + "index": {"dtype": "int64", "shape": (1,), "names": None}, + "task_index": {"dtype": "int64", "shape": (1,), "names": None}, +} + + +def get_parquet_file_size_in_mb(parquet_path): + metadata = pq.read_metadata(parquet_path) + total_uncompressed_size = 0 + for row_group in range(metadata.num_row_groups): + rg_metadata = metadata.row_group(row_group) + for column in range(rg_metadata.num_columns): + col_metadata = rg_metadata.column(column) + total_uncompressed_size += col_metadata.total_uncompressed_size + return total_uncompressed_size / (1024**2) + + +def get_hf_dataset_size_in_mb(hf_ds: Dataset) -> int: + return hf_ds.data.nbytes / (1024**2) + + +def get_pd_dataframe_size_in_mb(df: pandas.DataFrame) -> int: + # TODO(rcadene): unused? 
+ memory_usage_bytes = df.memory_usage(deep=True).sum() + return memory_usage_bytes / (1024**2) + + +def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int): + if file_idx == chunks_size - 1: + file_idx = 0 + chunk_idx += 1 + else: + file_idx += 1 + return chunk_idx, file_idx + + +def load_nested_dataset(pq_dir: Path) -> Dataset: + """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet + Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage + Concatenate all pyarrow references to return HF Dataset format + """ + paths = sorted(pq_dir.glob("*/*.parquet")) + if len(paths) == 0: + raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") + + # TODO(rcadene): set num_proc to accelerate conversion to pyarrow + datasets = [Dataset.from_parquet(str(path)) for path in paths] + return concatenate_datasets(datasets) + + +def get_parquet_num_frames(parquet_path): + metadata = pq.read_metadata(parquet_path) + return metadata.num_rows + + +def get_video_size_in_mb(mp4_path: Path): + file_size_bytes = mp4_path.stat().st_size + file_size_mb = file_size_bytes / (1024**2) + return file_size_mb + + +def concat_video_files(paths_to_cat: list[Path], root: Path, video_key: str, chunk_idx: int, file_idx: int): + # TODO(rcadene): move to video_utils.py + # TODO(rcadene): add docstring + tmp_dir = Path(tempfile.mkdtemp(dir=root)) + # Create a text file with the list of files to concatenate + path_concat_video_files = tmp_dir / "concat_video_files.txt" + with open(path_concat_video_files, "w") as f: + for ep_path in paths_to_cat: + f.write(f"file '{str(ep_path)}'\n") + + path_tmp_output = tmp_dir / "tmp_output.mp4" + command = [ + "ffmpeg", + "-y", + "-f", + "concat", + "-safe", + "0", + "-i", + str(path_concat_video_files), + "-c", + "copy", + str(path_tmp_output), + ] + subprocess.run(command, check=True) + + output_path = root / DEFAULT_VIDEO_PATH.format( + video_key=video_key, chunk_index=chunk_idx, file_index=file_idx + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.move(str(path_tmp_output), str(output_path)) + shutil.rmtree(str(tmp_dir)) + + +def get_video_duration_in_s(mp4_file: Path): + # TODO(rcadene): move to video_utils.py + command = [ + "ffprobe", + "-v", + "error", + "-show_entries", + "format=duration", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(mp4_file), + ] + result = subprocess.run( + command, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + return float(result.stdout) + + +def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: + """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. 
+
+    For example:
+    ```
+    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
+    >>> print(flatten_dict(dct))
+    {"a/b": 1, "a/c/d": 2, "e": 3}
+    ```
+    """
+    items = []
+    for k, v in d.items():
+        new_key = f"{parent_key}{sep}{k}" if parent_key else k
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, sep=sep).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)
+
+
+def unflatten_dict(d: dict, sep: str = "/") -> dict:
+    outdict = {}
+    for key, value in d.items():
+        parts = key.split(sep)
+        d = outdict
+        for part in parts[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        d[parts[-1]] = value
+    return outdict
+
+
+def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
+    serialized_dict = {}
+    for key, value in flatten_dict(stats).items():
+        if isinstance(value, (torch.Tensor, np.ndarray)):
+            serialized_dict[key] = value.tolist()
+        elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
+            serialized_dict[key] = value
+        elif isinstance(value, np.generic):
+            serialized_dict[key] = value.item()
+        elif isinstance(value, (int, float)):
+            serialized_dict[key] = value
+        else:
+            raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
+    return unflatten_dict(serialized_dict)
+
+
+def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
+    # Embed image bytes into the table before saving to parquet
+    format = dataset.format
+    dataset = dataset.with_format("arrow")
+    dataset = dataset.map(embed_table_storage, batched=False)
+    dataset = dataset.with_format(**format)
+    return dataset
+
+
+def load_json(fpath: Path) -> Any:
+    with open(fpath) as f:
+        return json.load(f)
+
+
+def write_json(data: dict, fpath: Path) -> None:
+    fpath.parent.mkdir(exist_ok=True, parents=True)
+    with open(fpath, "w") as f:
+        json.dump(data, f, indent=4, ensure_ascii=False)
+
+
+def write_info(info: dict, local_dir: Path):
+    write_json(info, local_dir / INFO_PATH)
+
+
+def load_info(local_dir: Path) -> dict:
+    info = load_json(local_dir / INFO_PATH)
+    for ft in info["features"].values():
+        ft["shape"] = tuple(ft["shape"])
+    return info
+
+
+def write_stats(stats: dict, local_dir: Path):
+    serialized_stats = serialize_dict(stats)
+    write_json(serialized_stats, local_dir / STATS_PATH)
+
+
+def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
+    stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
+    if not (local_dir / STATS_PATH).exists():
+        return None
+    stats = load_json(local_dir / STATS_PATH)
+    return cast_stats_to_numpy(stats)
+
+
+def write_hf_dataset(hf_dataset: Dataset, local_dir: Path):
+    if get_hf_dataset_size_in_mb(hf_dataset) > DEFAULT_DATA_FILE_SIZE_IN_MB:
+        raise NotImplementedError("Contact a maintainer.")
+
+    path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
+    path.parent.mkdir(parents=True, exist_ok=True)
+    hf_dataset.to_parquet(path)
+
+
+def write_tasks(tasks: pandas.DataFrame, local_dir: Path):
+    path = local_dir / DEFAULT_TASKS_PATH
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tasks.to_parquet(path)
+
+
+def load_tasks(local_dir: Path):
+    tasks = pd.read_parquet(local_dir / DEFAULT_TASKS_PATH)
+    return tasks
+
+
+def write_episodes(episodes: Dataset, local_dir: Path):
+    if get_hf_dataset_size_in_mb(episodes) > DEFAULT_DATA_FILE_SIZE_IN_MB:
+        raise NotImplementedError("Contact a maintainer.")
+
+    fpath = local_dir / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
+    fpath.parent.mkdir(parents=True, exist_ok=True)
+    episodes.to_parquet(fpath)
+
+
+def load_episodes(local_dir: Path) -> datasets.Dataset:
+    episodes = load_nested_dataset(local_dir / EPISODES_DIR)
+    # Select episode features/columns containing references to episode data and videos
+    # (e.g. tasks, dataset_from_index, dataset_to_index, data/chunk_index, data/file_index, etc.)
+    # This is to speed up access to this data, instead of having to load episode stats.
+    episodes = episodes.select_columns([key for key in episodes.features if not key.startswith("stats/")])
+    return episodes
+
+
+def backward_compatible_episodes_stats(
+    stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
+) -> dict[str, dict[str, np.ndarray]]:
+    return dict.fromkeys(episodes, stats)
+
+
+def load_image_as_numpy(
+    fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
+) -> np.ndarray:
+    img = PILImage.open(fpath).convert("RGB")
+    img_array = np.array(img, dtype=dtype)
+    if channel_first:  # (H, W, C) -> (C, H, W)
+        img_array = np.transpose(img_array, (2, 0, 1))
+    if np.issubdtype(dtype, np.floating):
+        img_array /= 255.0
+    return img_array
+
+
+def hf_transform_to_torch(items_dict: dict[str, list]):
+    """A transform function that converts items from a Hugging Face dataset (pyarrow)
+    to torch tensors. Importantly, images are converted from PIL, which corresponds to
+    a channel last representation (h w c) of uint8 type, to a torch image representation
+    with channel first (c h w) of float32 type in range [0,1].
+    """
+    for key in items_dict:
+        first_item = items_dict[key][0]
+        if isinstance(first_item, PILImage.Image):
+            to_tensor = transforms.ToTensor()
+            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif first_item is None:
+            pass
+        else:
+            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
+    return items_dict
+
+
+def is_valid_version(version: str) -> bool:
+    try:
+        packaging.version.parse(version)
+        return True
+    except packaging.version.InvalidVersion:
+        return False
+
+
+def check_version_compatibility(
+    repo_id: str,
+    version_to_check: str | packaging.version.Version,
+    current_version: str | packaging.version.Version,
+    enforce_breaking_major: bool = True,
+) -> None:
+    v_check = (
+        packaging.version.parse(version_to_check)
+        if not isinstance(version_to_check, packaging.version.Version)
+        else version_to_check
+    )
+    v_current = (
+        packaging.version.parse(current_version)
+        if not isinstance(current_version, packaging.version.Version)
+        else current_version
+    )
+    if v_check.major < v_current.major and enforce_breaking_major:
+        raise BackwardCompatibilityError(repo_id, v_check)
+    elif v_check.minor < v_current.minor:
+        logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
+
+
+def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
+    """Returns available valid versions (branches and tags) on given repo."""
+    api = HfApi()
+    repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
+    repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
+    repo_versions = []
+    for ref in repo_refs:
+        with contextlib.suppress(packaging.version.InvalidVersion):
+            repo_versions.append(packaging.version.parse(ref))
+
+    return repo_versions
+
+
+def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
+    """
+    Returns the version if available on repo or the latest compatible one.
+ Otherwise, will throw a `CompatibilityError`. + """ + target_version = ( + packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version + ) + hub_versions = get_repo_versions(repo_id) + + if not hub_versions: + raise RevisionNotFoundError( + f"""Your dataset must be tagged with a codebase version. + Assuming _version_ is the codebase_version value in the info.json, you can run this: + ```python + from huggingface_hub import HfApi + + hub_api = HfApi() + hub_api.create_tag("{repo_id}", tag="_version_", repo_type="dataset") + ``` + """ + ) + + if target_version in hub_versions: + return f"v{target_version}" + + compatibles = [ + v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor + ] + if compatibles: + return_version = max(compatibles) + if return_version < target_version: + logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") + return f"v{return_version}" + + lower_major = [v for v in hub_versions if v.major < target_version.major] + if lower_major: + raise BackwardCompatibilityError(repo_id, max(lower_major)) + + upper_versions = [v for v in hub_versions if v > target_version] + assert len(upper_versions) > 0 + raise ForwardCompatibilityError(repo_id, min(upper_versions)) + + +def get_hf_features_from_features(features: dict) -> datasets.Features: + hf_features = {} + for key, ft in features.items(): + if ft["dtype"] == "video": + continue + elif ft["dtype"] == "image": + hf_features[key] = datasets.Image() + elif ft["shape"] == (1,): + hf_features[key] = datasets.Value(dtype=ft["dtype"]) + elif len(ft["shape"]) == 1: + hf_features[key] = datasets.Sequence( + length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) + ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") + + return datasets.Features(hf_features) + + +def _validate_feature_names(features: dict[str, dict]) -> None: + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. 
Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + features = {} + joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == "action": + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == "observation": + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + +def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: + # TODO(rcadene): add fps for each feature + camera_ft = {} + if robot.cameras: + camera_ft = { + key: {"dtype": "video" if use_videos else "image", **ft} + for key, ft in robot.camera_features.items() + } + return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES} + + +def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFeature]: + # TODO(aliberts): Implement "type" in dataset features and simplify this + policy_features = {} + for key, ft in features.items(): + shape = ft["shape"] + if ft["dtype"] in ["image", "video"]: + type = FeatureType.VISUAL + if len(shape) != 3: + raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})") + + names = ft["names"] + # Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets. 
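+            # Swap a channel-last (h, w, c) shape to the channel-first (c, h, w)
+            # layout that policies expect.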
+ if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif key == "observation.environment_state": + type = FeatureType.ENV + elif key.startswith("observation"): + type = FeatureType.STATE + elif key.startswith("action"): + type = FeatureType.ACTION + else: + continue + + policy_features[key] = PolicyFeature( + type=type, + shape=shape, + ) + + return policy_features + + +def create_empty_dataset_info( + codebase_version: str, + fps: int, + features: dict, + use_videos: bool, + robot_type: str | None = None, +) -> dict: + return { + "codebase_version": codebase_version, + "robot_type": robot_type, + "total_episodes": 0, + "total_frames": 0, + "total_tasks": 0, + "chunks_size": DEFAULT_CHUNK_SIZE, + "data_files_size_in_mb": DEFAULT_DATA_FILE_SIZE_IN_MB, + "video_files_size_in_mb": DEFAULT_VIDEO_FILE_SIZE_IN_MB, + "fps": fps, + "splits": {}, + "data_path": DEFAULT_DATA_PATH, + "video_path": DEFAULT_VIDEO_PATH if use_videos else None, + "features": features, + } + + +def check_timestamps_sync( + timestamps: np.ndarray, + episode_indices: np.ndarray, + episode_data_index: dict[str, np.ndarray], + fps: int, + tolerance_s: float, + raise_value_error: bool = True, +) -> bool: + """ + This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance + to account for possible numerical error. + + Args: + timestamps (np.ndarray): Array of timestamps in seconds. + episode_indices (np.ndarray): Array indicating the episode index for each timestamp. + episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', + which identifies indices for the end of each episode. + fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. + tolerance_s (float): Allowed deviation from the expected (1/fps) difference. + raise_value_error (bool): Whether to raise a ValueError if the check fails. + + Returns: + bool: True if all checked timestamp differences lie within tolerance, False otherwise. + + Raises: + ValueError: If the check fails and `raise_value_error` is True. + """ + if timestamps.shape != episode_indices.shape: + raise ValueError( + "timestamps and episode_indices should have the same shape. " + f"Found {timestamps.shape=} and {episode_indices.shape=}." 
+        )
+
+    # Consecutive differences
+    diffs = np.diff(timestamps)
+    within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
+
+    # Mask to ignore differences at the boundaries between episodes
+    mask = np.ones(len(diffs), dtype=bool)
+    ignored_diffs = episode_data_index["to"][:-1] - 1  # indices at the end of each episode
+    mask[ignored_diffs] = False
+    filtered_within_tolerance = within_tolerance[mask]
+
+    # Check if all remaining diffs are within tolerance
+    if not np.all(filtered_within_tolerance):
+        # Track original indices before masking
+        original_indices = np.arange(len(diffs))
+        filtered_indices = original_indices[mask]
+        outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
+        outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
+
+        outside_tolerances = []
+        for idx in outside_tolerance_indices:
+            entry = {
+                "timestamps": [timestamps[idx], timestamps[idx + 1]],
+                "diff": diffs[idx],
+                "episode_index": episode_indices[idx].item()
+                if hasattr(episode_indices[idx], "item")
+                else episode_indices[idx],
+            }
+            outside_tolerances.append(entry)
+
+        if raise_value_error:
+            raise ValueError(
+                f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
+                This might be due to synchronization issues during data collection.
+                \n{pformat(outside_tolerances)}"""
+            )
+        return False
+
+    return True
+
+
+def check_delta_timestamps(
+    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
+) -> bool:
+    """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
+    This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
+    actual timestamps from the dataset.
+    """
+    outside_tolerance = {}
+    for key, delta_ts in delta_timestamps.items():
+        within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
+        if not all(within_tolerance):
+            outside_tolerance[key] = [
+                ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
+            ]
+
+    if len(outside_tolerance) > 0:
+        if raise_value_error:
+            raise ValueError(
+                f"""
+                The following delta_timestamps are found outside of tolerance range.
+                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
+                their values accordingly.
+                \n{pformat(outside_tolerance)}
+                """
+            )
+        return False
+
+    return True
+
+
+def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
+    delta_indices = {}
+    for key, delta_ts in delta_timestamps.items():
+        delta_indices[key] = [round(d * fps) for d in delta_ts]
+
+    return delta_indices
+
+
+def cycle(iterable):
+    """The equivalent of itertools.cycle, but safe for PyTorch dataloaders.
+
+    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
+    """
+    iterator = iter(iterable)
+    while True:
+        try:
+            yield next(iterator)
+        except StopIteration:
+            iterator = iter(iterable)
+
+
+def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
+    """Create a branch on an existing Hugging Face repo. Delete the branch if it already
+    exists before creating it.
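+
+    Example (illustrative; the repo id below is hypothetical):
+        create_branch("user/my_dataset", branch="v2.0", repo_type="dataset")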
+ """ + api = HfApi() + + branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches + refs = [branch.ref for branch in branches] + ref = f"refs/heads/{branch}" + if ref in refs: + api.delete_branch(repo_id, repo_type=repo_type, branch=branch) + + api.create_branch(repo_id, repo_type=repo_type, branch=branch) + + +def create_lerobot_dataset_card( + tags: list | None = None, + dataset_info: dict | None = None, + **kwargs, +) -> DatasetCard: + """ + Keyword arguments will be used to replace values in ./lerobot/datasets/card_template.md. + Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. + """ + card_tags = ["LeRobot"] + + if tags: + card_tags += tags + if dataset_info: + dataset_structure = "[meta/info.json](meta/info.json):\n" + dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n" + kwargs = {**kwargs, "dataset_structure": dataset_structure} + card_data = DatasetCardData( + license=kwargs.get("license"), + tags=card_tags, + task_categories=["robotics"], + configs=[ + { + "config_name": "default", + "data_files": "data/*/*.parquet", + } + ], + ) + + card_template = (importlib.resources.files("lerobot.datasets") / "card_template.md").read_text() + + return DatasetCard.from_template( + card_data=card_data, + template_str=card_template, + **kwargs, + ) + + +class IterableNamespace(SimpleNamespace): + """ + A namespace object that supports both dictionary-like iteration and dot notation access. + Automatically converts nested dictionaries into IterableNamespaces. + + This class extends SimpleNamespace to provide: + - Dictionary-style iteration over keys + - Access to items via both dot notation (obj.key) and brackets (obj["key"]) + - Dictionary-like methods: items(), keys(), values() + - Recursive conversion of nested dictionaries + + Args: + dictionary: Optional dictionary to initialize the namespace + **kwargs: Additional keyword arguments passed to SimpleNamespace + + Examples: + >>> data = {"name": "Alice", "details": {"age": 25}} + >>> ns = IterableNamespace(data) + >>> ns.name + 'Alice' + >>> ns.details.age + 25 + >>> list(ns.keys()) + ['name', 'details'] + >>> for key, value in ns.items(): + ... 
print(f"{key}: {value}") + name: Alice + details: IterableNamespace(age=25) + """ + + def __init__(self, dictionary: dict[str, Any] = None, **kwargs): + super().__init__(**kwargs) + if dictionary is not None: + for key, value in dictionary.items(): + if isinstance(value, dict): + setattr(self, key, IterableNamespace(value)) + else: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(vars(self)) + + def __getitem__(self, key: str) -> Any: + return vars(self)[key] + + def items(self): + return vars(self).items() + + def values(self): + return vars(self).values() + + def keys(self): + return vars(self).keys() + + +def validate_frame(frame: dict, features: dict): + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) + + # task is a special required field that's not part of regular features + if "task" not in actual_features: + raise ValueError("Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n") + + # Remove task from actual_features for regular feature validation + actual_features_for_validation = actual_features - {"task"} + + error_message = validate_features_presence(actual_features_for_validation, expected_features) + + common_features = actual_features_for_validation & expected_features + for name in common_features: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence(actual_features: set[str], expected_features: set[str]): + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - expected_features + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +): + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. 
Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
+def validate_feature_image_or_video(name: str, expected_shape: list[int], value: np.ndarray | PILImage.Image):
+    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c, h, w = expected_shape
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
+    elif isinstance(value, PILImage.Image):
+        pass
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
+def validate_feature_string(name: str, value: str):
+    if not isinstance(value, str):
+        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
+    return ""
+
+
+def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
+    if "size" not in episode_buffer:
+        raise ValueError("size key not found in episode_buffer")
+
+    if "task" not in episode_buffer:
+        raise ValueError("task key not found in episode_buffer")
+
+    if episode_buffer["episode_index"] != total_episodes:
+        # TODO(aliberts): Add option to use existing episode_index
+        raise NotImplementedError(
+            "You might have manually provided the episode_buffer with an episode_index that doesn't "
+            "match the total number of episodes already in the dataset. This is not supported for now."
+        )
+
+    if episode_buffer["size"] == 0:
+        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
+
+    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
+    if buffer_keys != set(features):
+        raise ValueError(
+            f"Features from `episode_buffer` don't match the ones in `features`. "
+            f"In episode_buffer not in features: {buffer_keys - set(features)}. "
+            f"In features not in episode_buffer: {set(features) - buffer_keys}"
+        )
+
+
+def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
+    """This function correctly writes to parquet a pandas DataFrame that contains images encoded by HF dataset.
+    This way, it can be loaded by HF dataset and correctly formatted images are returned.
+    """
+    # TODO(qlhoest): replace this weird syntax with `df.to_parquet(path)` only
+    datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
diff --git a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py
index e69de29bb..27cd56e6f 100644
--- a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py
+++ b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py
@@ -0,0 +1,114 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
+2.1. It will:
+
+- Generate per-episode stats and write them to `episodes_stats.jsonl`.
+- Check consistency between these new stats and the old ones.
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
+- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
+
+Usage:
+
+```bash
+python lerobot/datasets/v21/convert_dataset_v20_to_v21.py \
+    --repo-id=aliberts/koch_tutorial
+```
+
+"""
+
+import argparse
+import logging
+
+from huggingface_hub import HfApi
+
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
+from lerobot.datasets.utils import LEGACY_EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
+from lerobot.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
+
+V20 = "v2.0"
+V21 = "v2.1"
+
+
+class SuppressWarnings:
+    def __enter__(self):
+        self.previous_level = logging.getLogger().getEffectiveLevel()
+        logging.getLogger().setLevel(logging.ERROR)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        logging.getLogger().setLevel(self.previous_level)
+
+
+def convert_dataset(
+    repo_id: str,
+    branch: str | None = None,
+    num_workers: int = 4,
+):
+    with SuppressWarnings():
+        dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
+
+    if (dataset.root / LEGACY_EPISODES_STATS_PATH).is_file():
+        (dataset.root / LEGACY_EPISODES_STATS_PATH).unlink()
+
+    convert_stats(dataset, num_workers=num_workers)
+    ref_stats = load_stats(dataset.root)
+    check_aggregate_stats(dataset, ref_stats)
+
+    dataset.meta.info["codebase_version"] = CODEBASE_VERSION
+    write_info(dataset.meta.info, dataset.root)
+
+    dataset.push_to_hub(branch=branch, tag_version=False, allow_patterns="meta/")
+
+    # delete old stats.json file
+    if (dataset.root / STATS_PATH).is_file():
+        (dataset.root / STATS_PATH).unlink()
+
+    hub_api = HfApi()
+    if hub_api.file_exists(
+        repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
+    ):
+        hub_api.delete_file(
+            path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
+        )
+
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        required=True,
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+    )
+    parser.add_argument(
+        "--branch",
+        type=str,
+        default=None,
+        help="Repo branch to push your dataset. Defaults to the main branch.",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallelizing stats compute. Defaults to 4.",
+    )
+
+    args = parser.parse_args()
+    convert_dataset(**vars(args))
diff --git a/src/lerobot/datasets/v21/convert_stats.py b/src/lerobot/datasets/v21/convert_stats.py
index e69de29bb..a6b4f4afd 100644
--- a/src/lerobot/datasets/v21/convert_stats.py
+++ b/src/lerobot/datasets/v21/convert_stats.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +from tqdm import tqdm + +from lerobot.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import legacy_write_episode_stats + + +def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: + ep_len = dataset.meta.episodes[episode_index]["length"] + sampled_indices = sample_indices(ep_len) + query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) + video_frames = dataset._query_videos(query_timestamps, episode_index) + return video_frames[ft_key].numpy() + + +def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): + ep_start_idx = dataset.episode_data_index["from"][ep_idx] + ep_end_idx = dataset.episode_data_index["to"][ep_idx] + ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) + + ep_stats = {} + for key, ft in dataset.features.items(): + if ft["dtype"] == "video": + # We sample only for videos + ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key) + else: + ep_ft_data = np.array(ep_data[key]) + + axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 + keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 + ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) + + if ft["dtype"] in ["image", "video"]: # remove batch dim + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() + } + + dataset.meta.episodes_stats[ep_idx] = ep_stats + + +def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): + assert dataset.episodes is None + print("Computing episodes stats") + total_episodes = dataset.meta.total_episodes + if num_workers > 0: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = { + executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx + for ep_idx in range(total_episodes) + } + for future in tqdm(as_completed(futures), total=total_episodes): + future.result() + else: + for ep_idx in tqdm(range(total_episodes)): + convert_episode_stats(dataset, ep_idx) + + for ep_idx in tqdm(range(total_episodes)): + legacy_write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) + + +def check_aggregate_stats( + dataset: LeRobotDataset, + reference_stats: dict[str, dict[str, np.ndarray]], + video_rtol_atol: tuple[float] = (1e-2, 1e-2), + default_rtol_atol: tuple[float] = (5e-6, 6e-5), +): + """Verifies that the aggregated stats from episodes_stats are close to reference stats.""" + agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values())) + for key, ft in dataset.features.items(): + # These values might need some fine-tuning + if ft["dtype"] == "video": + # to account for image sub-sampling + rtol, atol = video_rtol_atol + else: + rtol, atol = default_rtol_atol + + for stat, val in agg_stats[key].items(): + if key in reference_stats and stat in reference_stats[key]: + err_msg = 
f"feature='{key}' stats='{stat}'" + np.testing.assert_allclose( + val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg + ) diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 1e87d19a3..739a87786 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -11,7 +11,7 @@ This script will help you convert any LeRobot dataset already pushed to the hub Usage: ```bash -python lerobot/common/datasets/v30/convert_dataset_v21_to_v30.py \ +python lerobot/datasets/v30/convert_dataset_v21_to_v30.py \ --repo-id=lerobot/pusht ``` @@ -30,10 +30,10 @@ from datasets import Dataset, Features, Image from huggingface_hub import HfApi, snapshot_download from requests import HTTPError -from lerobot.common.constants import HF_LEROBOT_HOME -from lerobot.common.datasets.compute_stats import aggregate_stats -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset -from lerobot.common.datasets.utils import ( +from lerobot.constants import HF_LEROBOT_HOME +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_DATA_PATH, diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index e69de29bb..b0f6c15c2 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import json +import logging +import subprocess +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, ClassVar + +import pyarrow as pa +import torch +import torchvision +from datasets.features.features import register_feature +from PIL import Image + + +def get_safe_default_codec(): + if importlib.util.find_spec("torchcodec"): + return "torchcodec" + else: + logging.warning( + "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" + ) + return "pyav" + + +def decode_video_frames( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + backend: str | None = None, +) -> torch.Tensor: + """ + Decodes video frames using the specified backend. + + Args: + video_path (Path): Path to the video file. + timestamps (list[float]): List of timestamps to extract frames. + tolerance_s (float): Allowed deviation in seconds for frame retrieval. + backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".. + + Returns: + torch.Tensor: Decoded frames. + + Currently supports torchcodec on cpu and pyav. 
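+
+    Example (a minimal sketch; the path and fps below are hypothetical):
+        # Fetch two frames of a 10 fps video, using the usual tolerance of 1 / fps - 1e-4.
+        frames = decode_video_frames("episode_000000.mp4", timestamps=[0.0, 0.1], tolerance_s=1 / 10 - 1e-4)
+        assert frames.shape[0] == 2  # (num_frames, c, h, w), float32 in [0, 1]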
+    """
+    if backend is None:
+        backend = get_safe_default_codec()
+    if backend == "torchcodec":
+        return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
+    elif backend in ["pyav", "video_reader"]:
+        return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
+    else:
+        raise ValueError(f"Unsupported video backend: {backend}")
+
+
+def decode_video_frames_torchvision(
+    video_path: Path | str,
+    timestamps: list[float],
+    tolerance_s: float,
+    backend: str = "pyav",
+    log_loaded_timestamps: bool = False,
+) -> torch.Tensor:
+    """Loads frames associated with the requested timestamps of a video.
+
+    The backend can be either "pyav" (default) or "video_reader".
+    "video_reader" requires installing torchvision from source, see:
+    https://github.com/pytorch/vision/blob/main/torchvision/csrc/io/decoder/gpu/README.rst
+    (note that you need to compile against ffmpeg<4.3)
+
+    While both use cpu, "video_reader" is supposedly faster than "pyav" but requires additional setup.
+    For more info on video decoding, see `benchmark/video/README.md`
+
+    See torchvision doc for more info on these two backends:
+    https://pytorch.org/vision/0.18/index.html?highlight=backend#torchvision.set_video_backend
+
+    Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
+    the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
+    that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
+    and all subsequent frames until reaching the requested frame. The number of key frames in a video
+    can be adjusted during encoding to take into account decoding time and video size in bytes.
+    """
+    video_path = str(video_path)
+
+    # set backend
+    keyframes_only = False
+    torchvision.set_video_backend(backend)
+    if backend == "pyav":
+        keyframes_only = True  # pyav doesn't support accurate seek
+
+    # set a video stream reader
+    # TODO(rcadene): also load audio stream at the same time
+    reader = torchvision.io.VideoReader(video_path, "video")
+
+    # set the first and last requested timestamps
+    # Note: previous timestamps are usually loaded, since we need to access the previous key frame
+    first_ts = min(timestamps)
+    last_ts = max(timestamps)
+
+    # access closest key frame of the first requested frame
+    # Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. 
key frame can be the first frame of the video) + # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek + reader.seek(first_ts, keyframes_only=keyframes_only) + + # load all frames until last requested frame + loaded_frames = [] + loaded_ts = [] + for frame in reader: + current_ts = frame["pts"] + if log_loaded_timestamps: + logging.info(f"frame loaded at timestamp={current_ts:.4f}") + loaded_frames.append(frame["data"]) + loaded_ts.append(current_ts) + if current_ts >= last_ts: + break + + if backend == "pyav": + reader.container.close() + + reader = None + + query_ts = torch.tensor(timestamps) + loaded_ts = torch.tensor(loaded_ts) + + # compute distances between each query timestamp and timestamps of all loaded frames + dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) + min_, argmin_ = dist.min(1) + + is_within_tol = min_ < tolerance_s + assert is_within_tol.all(), ( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + "It means that the closest frame that can be loaded from the video is too far away in time." + "This might be due to synchronization issues with timestamps during data collection." + "To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + f"\nbackend: {backend}" + ) + + # get closest frames to the query timestamps + # TODO(rcadene): remove torch.stack + closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) + closest_ts = loaded_ts[argmin_] + + if log_loaded_timestamps: + logging.info(f"{closest_ts=}") + + # convert to the pytorch format which is float32 in [0,1] range (and channel first) + closest_frames = closest_frames.type(torch.float32) / 255 + + assert len(timestamps) == len(closest_frames) + return closest_frames + + +def decode_video_frames_torchcodec( + video_path: Path | str, + timestamps: list[float], + tolerance_s: float, + device: str = "cpu", + log_loaded_timestamps: bool = False, +) -> torch.Tensor: + """Loads frames associated with the requested timestamps of a video using torchcodec. + + Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors. + + Note: Video benefits from inter-frame compression. Instead of storing every frame individually, + the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to + that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame, + and all subsequent frames until reaching the requested frame. The number of key frames in a video + can be adjusted during encoding to take into account decoding time and video size in bytes. 
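+
+    Note: unlike the torchvision-based decoding above, timestamps are mapped to frame
+    indices with round(ts * average_fps) and fetched directly via decoder.get_frames_at,
+    so no sequential scan from the previous key frame happens in Python.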
+    """
+
+    if importlib.util.find_spec("torchcodec"):
+        from torchcodec.decoders import VideoDecoder
+    else:
+        raise ImportError("torchcodec is required but not available.")
+
+    # initialize video decoder
+    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
+    loaded_frames = []
+    loaded_ts = []
+    # get metadata for frame information
+    metadata = decoder.metadata
+    average_fps = metadata.average_fps
+
+    # convert timestamps to frame indices
+    frame_indices = [round(ts * average_fps) for ts in timestamps]
+
+    # retrieve frames based on indices
+    frames_batch = decoder.get_frames_at(indices=frame_indices)
+
+    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
+        loaded_frames.append(frame)
+        loaded_ts.append(pts.item())
+        if log_loaded_timestamps:
+            logging.info(f"Frame loaded at timestamp={pts:.4f}")
+
+    query_ts = torch.tensor(timestamps)
+    loaded_ts = torch.tensor(loaded_ts)
+
+    # compute distances between each query timestamp and loaded timestamps
+    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
+    min_, argmin_ = dist.min(1)
+
+    is_within_tol = min_ < tolerance_s
+    assert is_within_tol.all(), (
+        f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
+        "It means that the closest frame that can be loaded from the video is too far away in time."
+        "This might be due to synchronization issues with timestamps during data collection."
+        "To be safe, we advise to ignore this item during training."
+        f"\nqueried timestamps: {query_ts}"
+        f"\nloaded timestamps: {loaded_ts}"
+        f"\nvideo: {video_path}"
+    )
+
+    # get closest frames to the query timestamps
+    closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+    closest_ts = loaded_ts[argmin_]
+
+    if log_loaded_timestamps:
+        logging.info(f"{closest_ts=}")
+
+    # convert to float32 in [0,1] range (channel first)
+    closest_frames = closest_frames.type(torch.float32) / 255
+
+    assert len(timestamps) == len(closest_frames)
+    return closest_frames
+
+
+def encode_video_frames(
+    imgs_dir: Path | str,
+    video_path: Path | str,
+    fps: int,
+    vcodec: str = "libsvtav1",
+    pix_fmt: str = "yuv420p",
+    g: int | None = 2,
+    crf: int | None = 30,
+    fast_decode: int = 0,
+    log_level: str | None = "quiet",
+    overwrite: bool = False,
+) -> None:
+    """More info on ffmpeg arguments tuning in `benchmark/video/README.md`"""
+    video_path = Path(video_path)
+    imgs_dir = Path(imgs_dir)
+    video_path.parent.mkdir(parents=True, exist_ok=True)
+
+    ffmpeg_args = OrderedDict(
+        [
+            ("-f", "image2"),
+            ("-r", str(fps)),
+            ("-i", str(imgs_dir / "frame-%06d.png")),
+            ("-vcodec", vcodec),
+            ("-pix_fmt", pix_fmt),
+        ]
+    )
+
+    if g is not None:
+        ffmpeg_args["-g"] = str(g)
+
+    if crf is not None:
+        ffmpeg_args["-crf"] = str(crf)
+
+    if fast_decode:
+        key = "-svtav1-params" if vcodec == "libsvtav1" else "-tune"
+        value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
+        ffmpeg_args[key] = value
+
+    if log_level is not None:
+        ffmpeg_args["-loglevel"] = str(log_level)
+
+    ffmpeg_args = [item for pair in ffmpeg_args.items() for item in pair]
+    if overwrite:
+        ffmpeg_args.append("-y")
+
+    ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
+
+    if not video_path.exists():
+        raise OSError(
+            f"Video encoding did not work. File not found: {video_path}. "
+            f"Try running the command manually to debug: `{' '.join(ffmpeg_cmd)}`"
+        )
+
+
+@dataclass
+class VideoFrame:
+    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
+    """
+    Provides a type for a dataset containing video frames.
+
+    Example:
+
+    ```python
+    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+    features = {"image": VideoFrame()}
+    Dataset.from_dict(data_dict, features=Features(features))
+    ```
+    """
+
+    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
+    _type: str = field(default="VideoFrame", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
+
+
+with warnings.catch_warnings():
+    warnings.filterwarnings(
+        "ignore",
+        "'register_feature' is experimental and might be subject to breaking changes in the future.",
+        category=UserWarning,
+    )
+    # to make VideoFrame available in HuggingFace `datasets`
+    register_feature(VideoFrame, "VideoFrame")
+
+
+def get_audio_info(video_path: Path | str) -> dict:
+    ffprobe_audio_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "a:0",
+        "-show_entries",
+        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    audio_stream_info = info["streams"][0] if info.get("streams") else None
+    if audio_stream_info is None:
+        return {"has_audio": False}
+
+    # Return the information, defaulting to None if no audio stream is present
+    return {
+        "has_audio": True,
+        "audio.channels": audio_stream_info.get("channels", None),
+        "audio.codec": audio_stream_info.get("codec_name", None),
+        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
+        "audio.sample_rate": int(audio_stream_info["sample_rate"])
+        if audio_stream_info.get("sample_rate")
+        else None,
+        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
+        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
+    }
+
+
+def get_video_info(video_path: Path | str) -> dict:
+    ffprobe_video_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "v:0",
+        "-show_entries",
+        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    video_stream_info = info["streams"][0]
+
+    # Calculate fps from r_frame_rate
+    r_frame_rate = video_stream_info["r_frame_rate"]
+    num, denom = map(int, r_frame_rate.split("/"))
+    fps = num / denom
+
+    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
+
+    video_info = {
+        "video.fps": fps,
+        "video.height": video_stream_info["height"],
+        "video.width": video_stream_info["width"],
+        "video.channels": pixel_channels,
+        "video.codec": video_stream_info["codec_name"],
+        "video.pix_fmt": video_stream_info["pix_fmt"],
+        "video.is_depth_map": False,
+        **get_audio_info(video_path),
+    }
+
+    return video_info
+
+
+def get_video_pixel_channels(pix_fmt: str) -> int:
+    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
+        return 1
+    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
+        return 4
+    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
+        return 3
+    else:
+        raise ValueError(f"Unknown pixel format: {pix_fmt}")
+
+
+def get_image_pixel_channels(image: Image.Image):
+    if image.mode == "L":
+        return 1  # Grayscale
+    elif image.mode == "LA":
+        return 2  # Grayscale + Alpha
+    elif image.mode == "RGB":
+        return 3  # RGB
+    elif image.mode == "RGBA":
+        return 4  # RGBA
+    else:
+        raise ValueError(f"Unknown image mode: {image.mode}")
diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index e69de29bb..b7c104cf6 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -0,0 +1,384 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import os.path as osp
+import platform
+import select
+import subprocess
+import sys
+import time
+from copy import copy, deepcopy
+from datetime import datetime, timezone
+from pathlib import Path
+from statistics import mean
+
+import numpy as np
+import torch
+
+
+def none_or_int(value):
+    if value == "None":
+        return None
+    return int(value)
+
+
+def inside_slurm():
+    """Check whether the python process was launched through slurm"""
+    # TODO(rcadene): return False for interactive mode `--pty bash`
+    return "SLURM_JOB_ID" in os.environ
+
+
+def auto_select_torch_device() -> torch.device:
+    """Tries to automatically select a torch device."""
+    if torch.cuda.is_available():
+        logging.info("Cuda backend detected, using cuda.")
+        return torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        logging.info("Metal backend detected, using mps.")
+        return torch.device("mps")
+    else:
+        logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
+        return torch.device("cpu")
+
+
+# TODO(Steven): Remove `log`. log shouldn't be an argument, this should be handled by the logger level
+def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
+    """Given a string, return a torch.device with checks on whether the device is available."""
+    try_device = str(try_device)
+    match try_device:
+        case "cuda":
+            assert torch.cuda.is_available()
+            device = torch.device("cuda")
+        case "mps":
+            assert torch.backends.mps.is_available()
+            device = torch.device("mps")
+        case "cpu":
+            device = torch.device("cpu")
+            if log:
+                logging.warning("Using CPU, this will be slow.")
+        case _:
+            device = torch.device(try_device)
+            if log:
+                logging.warning(f"Using custom {try_device} device.")
+
+    return device
+
+
+def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
+    """
+    mps is currently not compatible with float64
+    """
+    if isinstance(device, torch.device):
+        device = device.type
+    if device == "mps" and dtype == torch.float64:
+        return torch.float32
+    else:
+        return dtype
+
+
+def is_torch_device_available(try_device: str) -> bool:
+    try_device = str(try_device)  # Ensure try_device is a string
+    if try_device == "cuda":
+        return torch.cuda.is_available()
+    elif try_device == "mps":
+        return torch.backends.mps.is_available()
+    elif try_device == "cpu":
+        return True
+    else:
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+
+
+def is_amp_available(device: str):
+    if device in ["cuda", "cpu"]:
+        return True
+    elif device == "mps":
+        return False
+    else:
+        raise ValueError(f"Unknown device '{device}'.")
+
+
+def init_logging(log_file: Path | None = None, display_pid: bool = False):
+    def custom_format(record):
+        dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+        fnameline = f"{record.pathname}:{record.lineno}"
+
+        # NOTE: Display PID is useful for multi-process logging.
+ if display_pid: + pid_str = f"[PID: {os.getpid()}]" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + else: + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + return message + + logging.basicConfig(level=logging.INFO) + + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + formatter = logging.Formatter() + formatter.format = custom_format + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logging.getLogger().addHandler(console_handler) + + if log_file is not None: + # Additionally write logs to file + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + + +def format_big_number(num, precision=0): + suffixes = ["", "K", "M", "B", "T", "Q"] + divisor = 1000.0 + + for suffix in suffixes: + if abs(num) < divisor: + return f"{num:.{precision}f}{suffix}" + num /= divisor + + return num + + +def _relative_path_between(path1: Path, path2: Path) -> Path: + """Returns path1 relative to path2.""" + path1 = path1.absolute() + path2 = path2.absolute() + try: + return path1.relative_to(path2) + except ValueError: # most likely because path1 is not a subpath of path2 + common_parts = Path(osp.commonpath([path1, path2])).parts + return Path( + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + ) + + +def print_cuda_memory_usage(): + """Use this function to locate and debug memory leak.""" + import gc + + gc.collect() + # Also clear the cache if you want to fully release the memory + torch.cuda.empty_cache() + print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) + print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) + print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) + print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) + + +def capture_timestamp_utc(): + return datetime.now(timezone.utc) + + +def say(text: str, blocking: bool = False): + system = platform.system() + + if system == "Darwin": + cmd = ["say", text] + + elif system == "Linux": + cmd = ["spd-say", text] + if blocking: + cmd.append("--wait") + + elif system == "Windows": + cmd = [ + "PowerShell", + "-Command", + "Add-Type -AssemblyName System.Speech; " + f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')", + ] + + else: + raise RuntimeError("Unsupported operating system for text-to-speech.") + + if blocking: + subprocess.run(cmd, check=True) + else: + subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0) + + +def log_say(text: str, play_sounds: bool = True, blocking: bool = False): + logging.info(text) + + if play_sounds: + say(text, blocking) + + +def get_channel_first_image_shape(image_shape: tuple) -> tuple: + shape = copy(image_shape) + if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w) + shape = (shape[2], shape[0], shape[1]) + elif not (shape[0] < shape[1] and shape[0] < shape[2]): + raise ValueError(image_shape) + + return shape + + +def has_method(cls: object, method_name: str) -> bool: + return hasattr(cls, method_name) and callable(getattr(cls, method_name)) + + +def is_valid_numpy_dtype_string(dtype_str: str) -> bool: + """ + Return True if a given string can be converted to a numpy dtype. 
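+
+    Example:
+        is_valid_numpy_dtype_string("float32")  # True
+        is_valid_numpy_dtype_string("video")  # False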
+ """ + try: + # Attempt to convert the string to a numpy dtype + np.dtype(dtype_str) + return True + except TypeError: + # If a TypeError is raised, the string is not a valid dtype + return False + + +def enter_pressed() -> bool: + if platform.system() == "Windows": + import msvcrt + + if msvcrt.kbhit(): + key = msvcrt.getch() + return key in (b"\r", b"\n") # enter key + return False + else: + return select.select([sys.stdin], [], [], 0)[0] and sys.stdin.readline().strip() == "" + + +def move_cursor_up(lines): + """Move the cursor up by a specified number of lines.""" + print(f"\033[{lines}A", end="") + + +def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): + days = int(elapsed_time_s // (24 * 3600)) + elapsed_time_s %= 24 * 3600 + hours = int(elapsed_time_s // 3600) + elapsed_time_s %= 3600 + minutes = int(elapsed_time_s // 60) + seconds = elapsed_time_s % 60 + return days, hours, minutes, seconds + + +class TimerManager: + """ + Lightweight utility to measure elapsed time. + + Examples + -------- + ```python + # Example 1: Using context manager + timer = TimerManager("Policy", log=False) + for _ in range(3): + with timer: + time.sleep(0.01) + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + + ```python + # Example 2: Using start/stop methods + timer = TimerManager("Policy", log=False) + timer.start() + time.sleep(0.01) + timer.stop() + print(timer.last, timer.fps_avg, timer.percentile(90)) # Prints: 0.01 100.0 0.01 + ``` + """ + + def __init__( + self, + label: str = "Elapsed-time", + log: bool = True, + logger: logging.Logger | None = None, + ): + self.label = label + self.log = log + self.logger = logger + self._start: float | None = None + self._history: list[float] = [] + + def __enter__(self): + return self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + def start(self): + self._start = time.perf_counter() + return self + + def stop(self) -> float: + if self._start is None: + raise RuntimeError("Timer was never started.") + elapsed = time.perf_counter() - self._start + self._history.append(elapsed) + self._start = None + if self.log: + if self.logger is not None: + self.logger.info(f"{self.label}: {elapsed:.6f} s") + else: + logging.info(f"{self.label}: {elapsed:.6f} s") + return elapsed + + def reset(self): + self._history.clear() + + @property + def last(self) -> float: + return self._history[-1] if self._history else 0.0 + + @property + def avg(self) -> float: + return mean(self._history) if self._history else 0.0 + + @property + def total(self) -> float: + return sum(self._history) + + @property + def count(self) -> int: + return len(self._history) + + @property + def history(self) -> list[float]: + return deepcopy(self._history) + + @property + def fps_history(self) -> list[float]: + return [1.0 / t for t in self._history] + + @property + def fps_last(self) -> float: + return 0.0 if self.last == 0 else 1.0 / self.last + + @property + def fps_avg(self) -> float: + return 0.0 if self.avg == 0 else 1.0 / self.avg + + def percentile(self, p: float) -> float: + """ + Return the p-th percentile of recorded times. + """ + if not self._history: + return 0.0 + return float(np.percentile(self._history, p)) + + def fps_percentile(self, p: float) -> float: + """ + FPS corresponding to the p-th percentile time. 
+ """ + val = self.percentile(p) + return 0.0 if val == 0 else 1.0 / val diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index b67ed001e..6a1b3b9ff 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -1,7 +1,7 @@ import torch -from lerobot.common.datasets.aggregate import aggregate_datasets -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.lerobot_dataset import LeRobotDataset from tests.fixtures.constants import DUMMY_REPO_ID diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 1cac7363f..905cb23cf 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -26,19 +26,19 @@ from PIL import Image from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.image_writer import image_array_to_pil_image -from lerobot.common.datasets.lerobot_dataset import ( +from lerobot.datasets.factory import make_dataset +from lerobot.datasets.image_writer import image_array_to_pil_image +from lerobot.datasets.lerobot_dataset import ( LeRobotDataset, MultiLeRobotDataset, ) -from lerobot.common.datasets.utils import ( +from lerobot.datasets.utils import ( create_branch, hw_to_dataset_features, ) -from lerobot.common.envs.factory import make_env_config -from lerobot.common.policies.factory import make_policy_config -from lerobot.common.robots import make_robot_from_config +from lerobot.envs.factory import make_env_config +from lerobot.policies.factory import make_policy_config +from lerobot.robots import make_robot_from_config from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID @@ -107,7 +107,7 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory): # and test the small resulting function that validates the features def test_dataset_feature_with_forward_slash_raises_error(): # make sure dir does not exist - from lerobot.common.constants import HF_LEROBOT_HOME + from lerobot.constants import HF_LEROBOT_HOME dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" # make sure does not exist diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index d778f4412..91d661b3c 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -21,8 +21,8 @@ import torch from datasets import Dataset from huggingface_hub import DatasetCard -from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index -from lerobot.common.datasets.utils import ( +from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index +from lerobot.datasets.utils import ( create_lerobot_dataset_card, flatten_dict, hf_transform_to_torch, diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 253ab1a90..6441f8303 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -19,7 +19,7 @@ import pyarrow.compute as pc import pyarrow.parquet as pq import pytest -from lerobot.common.datasets.utils import ( +from lerobot.datasets.utils import ( write_episodes, write_hf_dataset, write_info, diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index c218d592d..97f77158a 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -18,7 +18,7 @@ import pandas as pd import pytest from 
huggingface_hub.utils import filter_repo_objects -from lerobot.common.datasets.utils import ( +from lerobot.datasets.utils import ( DEFAULT_DATA_PATH, DEFAULT_EPISODES_PATH, DEFAULT_TASKS_PATH, diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index bd6c99801..57ac0edc1 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -23,20 +23,20 @@ import torch from safetensors.torch import load_file from lerobot import available_policies -from lerobot.common.datasets.factory import make_dataset -from lerobot.common.datasets.utils import cycle, dataset_to_policy_features -from lerobot.common.envs.factory import make_env, make_env_config -from lerobot.common.envs.utils import preprocess_observation -from lerobot.common.optim.factory import make_optimizer_and_scheduler -from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler -from lerobot.common.policies.factory import ( +from lerobot.datasets.factory import make_dataset +from lerobot.datasets.utils import cycle, dataset_to_policy_features +from lerobot.envs.factory import make_env, make_env_config +from lerobot.envs.utils import preprocess_observation +from lerobot.optim.factory import make_optimizer_and_scheduler +from lerobot.policies.act.modeling_act import ACTTemporalEnsembler +from lerobot.policies.factory import ( get_policy_class, make_policy, make_policy_config, ) -from lerobot.common.policies.normalize import Normalize, Unnormalize -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.utils.random_utils import seeded_context +from lerobot.policies.normalize import Normalize, Unnormalize +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
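For reviewers following the test changes: every hunk above applies the same mechanical rename, dropping the `common` package level. A representative before/after, with paths taken from the hunks themselves and shown here only to summarize the pattern:

```python
# old layout
# from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# new layout
from lerobot.datasets.lerobot_dataset import LeRobotDataset
```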