try fix 3

This commit is contained in:
Steven Palma
2025-11-05 21:08:23 +01:00
parent 550866a3c5
commit a19bd6e84d
2 changed files with 73 additions and 31 deletions
+73 -3
View File
@@ -19,6 +19,7 @@ import shutil
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
import datasets
import numpy as np
@@ -31,6 +32,8 @@ import torch
import torch.utils
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.datasets.image_writer import AsyncImageWriter, write_image
@@ -50,11 +53,9 @@ from lerobot.datasets.utils import (
get_file_size_in_mb,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_info,
load_nested_dataset,
load_stats,
load_tasks,
update_chunk_file_indices,
@@ -832,8 +833,50 @@ class LeRobotDataset(torch.utils.data.Dataset):
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
# We MUST import this here to avoid circular dependency
# (utils imports lerobot_dataset for backward_compatibility)
from lerobot.datasets.utils import hf_transform_to_torch
features = get_hf_features_from_features(self.features)
hf_dataset = load_nested_dataset(self.root / "data", features=features)
# This is the v2.1 logic that forces an efficient, pre-decoded cache build.
# This is the key to performance for dtype="image" datasets.
# 1. Check if specific episodes are requested by the user.
if self.episodes is not None:
# Get the unique set of parquet files for the requested episodes
fpaths = set()
for ep_idx in self.episodes:
# Need to read metadata to find the file path for this episode
# We use self.meta.episodes (the loaded dataset) here
ep_meta = self.meta.episodes[ep_idx]
chunk_idx = ep_meta["data/chunk_index"]
file_idx = ep_meta["data/file_index"]
fpath_str = self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
fpaths.add(str(self.root / fpath_str))
data_files = sorted(list(fpaths))
hf_dataset = datasets.load_dataset(
"parquet", data_files=data_files, features=features, split="train"
)
# Filter the loaded dataset to *only* include the requested episodes
# This is necessary because the v3 files contain multiple episodes.
requested_episodes_set = set(self.episodes)
hf_dataset = hf_dataset.filter(
lambda x: x["episode_index"] in requested_episodes_set,
batched=True, # Use batched=True for faster filtering
batch_size=1000,
)
else:
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
# Load all data files using data_dir, which is the most efficient.
data_dir = str(self.root / "data")
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@@ -1675,3 +1718,30 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Transformations: {self.image_transforms},\n"
f")"
)
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
-28
View File
@@ -35,7 +35,6 @@ from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.backward_compatibility import (
@@ -399,33 +398,6 @@ def load_image_as_numpy(
return img_array
def hf_transform_to_torch(items_dict: dict[str, list[Any]]) -> dict[str, list[torch.Tensor | str]]:
"""Convert a batch from a Hugging Face dataset to torch tensors.
This transform function converts items from Hugging Face dataset format (pyarrow)
to torch tensors. Importantly, images are converted from PIL objects (H, W, C, uint8)
to a torch image representation (C, H, W, float32) in the range [0, 1]. Other
types are converted to torch.tensor.
Args:
items_dict (dict): A dictionary representing a batch of data from a
Hugging Face dataset.
Returns:
dict: The batch with items converted to torch tensors.
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def is_valid_version(version: str) -> bool:
"""Check if a string is a valid PEP 440 version.