mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
try fix 6
This commit is contained in:
@@ -81,23 +81,30 @@ CODEBASE_VERSION = "v3.0"
|
|||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
|
def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
"""
|
||||||
to torch tensors. ...
|
Converts a single item (row) from a Hugging Face dataset to torch tensors.
|
||||||
[This is the v2.1 item-level transform]
|
This item-level transform allows `datasets` to build an efficient, pre-processed cache.
|
||||||
"""
|
"""
|
||||||
for key in items_dict:
|
for key in items_dict:
|
||||||
if items_dict[key] is None:
|
item = items_dict[key]
|
||||||
|
|
||||||
|
if item is None:
|
||||||
continue
|
continue
|
||||||
if isinstance(items_dict[key], PILImage.Image):
|
|
||||||
# PIL image (h w c) (uint8)
|
if isinstance(item, PILImage.Image):
|
||||||
to_tensor = transforms.ToTensor()
|
# Correctly transform PIL images to tensors
|
||||||
items_dict[key] = to_tensor(items_dict[key])
|
items_dict[key] = transforms.ToTensor()(item)
|
||||||
elif isinstance(items_dict[key], str):
|
elif isinstance(item, (str, bytes)):
|
||||||
# keep as is
|
# Let strings (like 'task') pass through untouched
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# This handles tensors, ints, floats, etc.
|
# Convert all other numeric types (int, float, list, np.ndarray) to tensors
|
||||||
items_dict[key] = torch.tensor(items_dict[key])
|
try:
|
||||||
|
items_dict[key] = torch.tensor(item)
|
||||||
|
except Exception as e:
|
||||||
|
# Catch errors like the one you saw
|
||||||
|
print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
|
||||||
|
raise e
|
||||||
return items_dict
|
return items_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -857,15 +864,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
features = get_hf_features_from_features(self.features)
|
features = get_hf_features_from_features(self.features)
|
||||||
|
|
||||||
# 1. Check if specific episodes are requested by the user.
|
|
||||||
# This is the "data_files" path, which may be slow, but is
|
|
||||||
# necessary for visualization or evaluation on a subset.
|
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
# Get the unique set of parquet files for the requested episodes
|
# Path for episode-specific loading (e.g., visualization)
|
||||||
fpaths = set()
|
fpaths = set()
|
||||||
for ep_idx in self.episodes:
|
for ep_idx in self.episodes:
|
||||||
# Need to read metadata to find the file path for this episode
|
|
||||||
# Use the pre-loaded metadata list
|
|
||||||
ep_meta = self.episodes_metadata_list[ep_idx]
|
ep_meta = self.episodes_metadata_list[ep_idx]
|
||||||
chunk_idx = ep_meta["data/chunk_index"]
|
chunk_idx = ep_meta["data/chunk_index"]
|
||||||
file_idx = ep_meta["data/file_index"]
|
file_idx = ep_meta["data/file_index"]
|
||||||
@@ -878,8 +880,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
"parquet", data_files=data_files, features=features, split="train"
|
"parquet", data_files=data_files, features=features, split="train"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filter the loaded dataset to *only* include the requested episodes
|
|
||||||
# This is necessary because v3 files can contain multiple episodes.
|
|
||||||
requested_episodes_set = set(self.episodes)
|
requested_episodes_set = set(self.episodes)
|
||||||
hf_dataset = hf_dataset.filter(
|
hf_dataset = hf_dataset.filter(
|
||||||
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
|
lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
|
||||||
@@ -887,7 +887,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
|
# THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
|
||||||
# We must use `data_dir` to trigger the v2.1-style efficient cache.
|
# Use `data_dir` to trigger the v2.1-style efficient cache.
|
||||||
data_dir = str(self.root / "data")
|
data_dir = str(self.root / "data")
|
||||||
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
|
hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user