mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
try fix 6
@@ -81,23 +81,30 @@ CODEBASE_VERSION = "v3.0"
 def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
-    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
-    to torch tensors. ...
-    [This is the v2.1 item-level transform]
+    """
+    Converts a single item (row) from a Hugging Face dataset to torch tensors.
+
+    This item-level transform allows `datasets` to build an efficient, pre-processed cache.
     """
     for key in items_dict:
-        if items_dict[key] is None:
+        item = items_dict[key]
+
+        if item is None:
             continue
-        if isinstance(items_dict[key], PILImage.Image):
-            # PIL image (h w c) (uint8)
-            to_tensor = transforms.ToTensor()
-            items_dict[key] = to_tensor(items_dict[key])
-        elif isinstance(items_dict[key], str):
-            # keep as is
+
+        if isinstance(item, PILImage.Image):
+            # Transform PIL images (h w c, uint8) to float32 tensors (c h w, in [0, 1])
+            items_dict[key] = transforms.ToTensor()(item)
+        elif isinstance(item, (str, bytes)):
+            # Let strings (like 'task') pass through untouched
             pass
         else:
-            # This handles tensors, ints, floats, etc.
-            items_dict[key] = torch.tensor(items_dict[key])
+            # Convert all other numeric types (int, float, list, np.ndarray) to tensors
+            try:
+                items_dict[key] = torch.tensor(item)
+            except Exception as e:
+                # Report the offending key and value before re-raising
+                print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
+                raise e
     return items_dict
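For reference, a minimal sketch of what the patched item-level transform does to a single row. The feature keys below are hypothetical, chosen only for illustration, and `hf_transform_to_torch` is assumed to be the version added above:

from PIL import Image as PILImage

row = {
    "observation.image": PILImage.new("RGB", (96, 96)),  # hypothetical keys, for illustration only
    "task": "pick up the cube",
    "action": [0.1, 0.2, 0.3],
}
row = hf_transform_to_torch(row)
# row["observation.image"] -> float32 tensor of shape (3, 96, 96), values in [0, 1]
# row["task"]              -> str, passed through untouched
# row["action"]            -> tensor([0.1000, 0.2000, 0.3000])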
@@ -857,15 +864,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         features = get_hf_features_from_features(self.features)
 
-        # 1. Check if specific episodes are requested by the user.
-        # This is the "data_files" path, which may be slow, but is
-        # necessary for visualization or evaluation on a subset.
         if self.episodes is not None:
-            # Get the unique set of parquet files for the requested episodes
+            # Path for episode-specific loading (e.g., visualization)
             fpaths = set()
             for ep_idx in self.episodes:
-                # Need to read metadata to find the file path for this episode
+                # Use the pre-loaded metadata list
                 ep_meta = self.episodes_metadata_list[ep_idx]
                 chunk_idx = ep_meta["data/chunk_index"]
                 file_idx = ep_meta["data/file_index"]
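The chunk and file indices identify which parquet file holds the episode. A hypothetical sketch of how they could be turned into paths to collect in `fpaths`, assuming a v3-style layout of `data/chunk-XXX/file-XXX.parquet` (the template string is an assumption here, not quoted from the repo):

# Inside the loop above; `self.root` is the dataset root directory.
fpath = self.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
fpaths.add(str(fpath))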
@@ -878,8 +880,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "parquet", data_files=data_files, features=features, split="train"
             )
 
-            # Filter the loaded dataset to *only* include the requested episodes
-            # This is necessary because v3 files can contain multiple episodes.
             requested_episodes_set = set(self.episodes)
             hf_dataset = hf_dataset.filter(
                 lambda x: [idx in requested_episodes_set for idx in x["episode_index"]], batched=True, batch_size=1000
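With `batched=True`, `datasets.Dataset.filter` passes the predicate a dict of column lists and expects one boolean per row, which is why it must iterate over the batch rather than test the whole column at once. A self-contained toy check of that behavior:

import datasets

ds = datasets.Dataset.from_dict({"episode_index": [0, 0, 1, 2, 2]})
keep = {0, 2}
filtered = ds.filter(lambda x: [idx in keep for idx in x["episode_index"]], batched=True)
print(filtered["episode_index"])  # [0, 0, 2, 2]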
@@ -887,7 +887,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         else:
             # THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
-            # We must use `data_dir` to trigger the v2.1-style efficient cache.
+            # Use `data_dir` to trigger the v2.1-style efficient cache.
             data_dir = str(self.root / "data")
             hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
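The two branches differ only in how `load_dataset` discovers the parquet files: `data_files` pins an explicit list (the episode-specific path, which the comments above note may be slow), while `data_dir` lets `datasets` scan the directory and reuse its prepared cache on later runs. A condensed sketch with illustrative paths:

import datasets

# Episode-specific path: enumerate exact files.
subset = datasets.load_dataset(
    "parquet", data_files=["data/chunk-000/file-000.parquet"], split="train"
)

# Training path: point at the directory; `datasets` caches the prepared result.
full = datasets.load_dataset("parquet", data_dir="data", split="train")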