Compare commits

...

10 Commits

Author SHA1 Message Date
Steven Palma bcc13f1d90 try fix 9 2025-11-05 21:52:15 +01:00
Steven Palma 76f25f6afd try fix 8 2025-11-05 21:49:04 +01:00
Steven Palma ce23681d4b try fix 7 2025-11-05 21:46:09 +01:00
Steven Palma e195f8d287 try fix 6 2025-11-05 21:42:31 +01:00
Steven Palma bbcffc4999 try fix 5 2025-11-05 21:34:10 +01:00
Steven Palma 20333abc72 try fix 4 2025-11-05 21:26:52 +01:00
Steven Palma 00a4e6bfb3 try fix 3 2025-11-05 21:09:53 +01:00
Steven Palma a19bd6e84d try fix 3 2025-11-05 21:08:23 +01:00
Steven Palma 550866a3c5 try fix 2 2025-11-05 20:49:29 +01:00
Steven Palma 3ec4e4ce37 try fix 2025-11-05 20:24:47 +01:00
2 changed files with 143 additions and 46 deletions
+136 -16
View File
@@ -19,6 +19,7 @@ import shutil
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
import datasets
import numpy as np
@@ -31,6 +32,8 @@ import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
@@ -50,11 +53,9 @@ from lerobot.datasets.utils import (
get_file_size_in_mb,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_info,
load_nested_dataset,
load_stats,
load_tasks,
update_chunk_file_indices,
@@ -79,6 +80,51 @@ from lerobot.utils.constants import HF_LEROBOT_HOME
CODEBASE_VERSION = "v3.0"
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""
Converts a batch from a Hugging Face dataset to torch tensors.
"""
# Create a single ToTensor transform instance to reuse
to_tensor = transforms.ToTensor()
for key in items_dict:
items_list = items_dict[key]
# Check if the list is non-empty
if not items_list:
continue
first_item = items_list[0]
if isinstance(first_item, PILImage.Image):
# This is the (slow) CPU-bound part.
# We convert every image in the batch list to a tensor.
items_dict[key] = [to_tensor(img) for img in items_list]
elif isinstance(first_item, (str, bytes)):
# List of strings (e.g., 'task'), do nothing
pass
elif first_item is None:
# List of Nones, do nothing
pass
else:
# List of other things (int, float, list, np.ndarray)
try:
# Convert each item in the list to a tensor
items_dict[key] = [torch.tensor(item) for item in items_list]
except Exception as e:
# This catch is what was missing from the original v3.0 code
print(
f"Error converting batch['{key}'] to tensor. First item: {first_item}, Type: {type(first_item)}"
)
raise e
return items_dict
class LeRobotDatasetMetadata:
def __init__(
self,
@@ -693,6 +739,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
# Pre-load episodes metadata into memory to avoid file I/O in __getitem__
self.episodes_metadata_list = [ep for ep in self.meta.episodes]
# Track dataset state for efficient incremental writing
self._lazy_loading = False
self._recorded_frames = self.meta.total_frames
@@ -829,8 +878,36 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc."""
    features = get_hf_features_from_features(self.features)
    # NOTE(review): this looks like a leftover from the previous implementation —
    # its result is unconditionally overwritten by both branches of the if/else
    # below; confirm it can be removed.
    hf_dataset = load_nested_dataset(self.root / "data", features=features)
    if self.episodes is not None:
        # Path for episode-specific loading (e.g., visualization)
        # Collect only the parquet files that contain the requested episodes.
        fpaths = set()
        for ep_idx in self.episodes:
            ep_meta = self.episodes_metadata_list[ep_idx]
            chunk_idx = ep_meta["data/chunk_index"]
            file_idx = ep_meta["data/file_index"]
            fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
            fpaths.add(str(self.root / fpath_str))
        data_files = sorted(list(fpaths))
        hf_dataset = datasets.load_dataset(
            "parquet", data_files=data_files, features=features, split="train"
        )
        # A parquet file can hold several episodes; keep only the requested ones.
        # NOTE(review): with batched=True the callable receives a batch (dict of
        # lists) and must return a list of bools, but this lambda treats
        # x["episode_index"] as a scalar — verify against datasets.Dataset.filter.
        requested_episodes_set = set(self.episodes)
        hf_dataset = hf_dataset.filter(
            lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
        )
    else:
        # THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
        # Use `data_dir` to trigger the v2.1-style efficient cache.
        data_dir = str(self.root / "data")
        hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
    # Item-level transform: converts PIL images / raw values to torch tensors on access.
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
@@ -909,7 +986,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return get_hf_features_from_features(self.features)
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
ep = self.meta.episodes[ep_idx]
ep = self.episodes_metadata_list[ep_idx]
ep_start = ep["dataset_from_index"]
ep_end = ep["dataset_to_index"]
query_indices = {
@@ -952,7 +1029,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
the main process and a subprocess fails to access it.
"""
ep = self.meta.episodes[ep_idx]
ep = self.episodes_metadata_list[ep_idx]
item = {}
for vid_key, query_ts in query_timestamps.items():
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
@@ -983,29 +1060,72 @@ class LeRobotDataset(torch.utils.data.Dataset):
def __getitem__(self, idx) -> dict:
    """Return one sample: frame tensors, optional delta stacks, decoded video frames, task string."""
    # Ensure dataset is loaded when we actually need to read from it
    self._ensure_hf_dataset_loaded()
    # NOTE(review): the two lines below appear to be remnants of the previous
    # implementation — both `item` and `ep_idx` are recomputed further down; confirm removal.
    item = self.hf_dataset[idx]
    ep_idx = item["episode_index"].item()
    # 1. Get query indices if deltas are needed
    query_indices = None
    padding = {}
    if self.delta_indices is not None:
        query_indices, padding = self._get_query_indices(idx, ep_idx)
        # NOTE(review): the three statements below also look like old-path remnants —
        # `query_result` is merged into `item`, but `item`/`query_indices`/`padding`
        # are all recomputed right after; verify against the intended version.
        query_result = self._query_hf_dataset(query_indices)
        item = {**item, **padding}
        for key, val in query_result.items():
            item[key] = val
        # We need the episode index *first* to get boundaries.
        # This is a small read for just one item.
        ep_idx_only = self.hf_dataset[idx : idx + 1]["episode_index"][0].item()
        query_indices, padding = self._get_query_indices(idx, ep_idx_only)
    # 2. Fetch all data (including images)
    if query_indices is not None:
        # --- Delta path ---
        # Fetch all keys (state, action, AND images) for all deltas
        item_batch = self.hf_dataset[query_indices["index"]]
        # hf_transform_to_torch (item-level) has already run,
        # so all values are tensors. We stack them.
        item = {}
        for key in item_batch:
            item[key] = torch.stack(item_batch[key])
        item.update(padding)
        # Use the "current" item's index/timestamp/ep_idx
        # (assuming 'index' is the key for the main query)
        current_idx_in_batch = query_indices["index"].index(idx)
        item["index"] = item["index"][current_idx_in_batch]
        item["timestamp"] = item["timestamp"][current_idx_in_batch]
        item["episode_index"] = item["episode_index"][current_idx_in_batch]
        item["task_index"] = item["task_index"][current_idx_in_batch]
        ep_idx = item["episode_index"].item()
    else:
        # --- Single-frame path ---
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()
    # 3. Handle videos (which are always separate)
    if len(self.meta.video_keys) > 0:
        # NOTE(review): the next two lines look like the pre-change implementation;
        # both `current_ts` and `query_timestamps` are recomputed just below, and
        # in the delta path item["timestamp"] was already reduced to a scalar above,
        # making the later re-indexing suspicious — confirm which version is live.
        current_ts = item["timestamp"].item()
        query_timestamps = self._get_query_timestamps(current_ts, query_indices)
        current_ts = (
            item["timestamp"].item()
            if query_indices is None
            else item["timestamp"][current_idx_in_batch].item()
        )
        video_query_indices = query_indices
        if video_query_indices is None:
            # If no deltas, create a dummy query_indices for _get_query_timestamps
            video_query_indices = {key: [idx] for key in self.meta.video_keys}
        query_timestamps = self._get_query_timestamps(current_ts, video_query_indices)
        video_frames = self._query_videos(query_timestamps, ep_idx)
        # video_frames are already stacked tensors (B, C, H, W) or (C, H, W)
        item = {**video_frames, **item}
    # 4. Apply image transforms
    if self.image_transforms is not None:
        image_keys = self.meta.camera_keys
        for cam in image_keys:
            # NOTE(review): the unguarded call below plus the guarded one after it
            # apply the transform twice for present keys — this looks like old and
            # new diff lines merged together; keep only the guarded form once confirmed.
            item[cam] = self.image_transforms(item[cam])
            if cam in item:  # videos or images
                item[cam] = self.image_transforms(item[cam])
    # Add task as a string
    # 5. Add task string
    task_idx = item["task_index"].item()
    item["task"] = self.meta.tasks.iloc[task_idx].name
    return item
+7 -30
View File
@@ -35,7 +35,6 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
@@ -116,10 +115,15 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
if len(paths) == 0:
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# Convert Path objects to a list of strings
file_paths = [str(path) for path in paths]
# Use datasets.load_dataset to force creation of an efficient cache
# This pre-decodes the images and avoids the on-the-fly bottleneck.
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
dataset = datasets.load_dataset("parquet", data_files=file_paths, features=features, split="train")
return dataset
def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -394,33 +398,6 @@ def load_image_as_numpy(
return img_array
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def is_valid_version(version: str) -> bool:
"""Check if a string is a valid PEP 440 version.