diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 99e210220..885b45582 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -81,23 +81,30 @@ CODEBASE_VERSION = "v3.0"
 
 
 def hf_transform_to_torch(items_dict: dict[str, Any]) -> dict[str, torch.Tensor | str]:
-    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
-    to torch tensors. ...
-    [This is the v2.1 item-level transform]
+    """
+    Converts a single item (row) from a Hugging Face dataset to torch tensors.
+    This item-level transform allows `datasets` to build an efficient, pre-processed cache.
     """
     for key in items_dict:
-        if items_dict[key] is None:
+        item = items_dict[key]
+
+        if item is None:
             continue
-        if isinstance(items_dict[key], PILImage.Image):
-            # PIL image (h w c) (uint8)
-            to_tensor = transforms.ToTensor()
-            items_dict[key] = to_tensor(items_dict[key])
-        elif isinstance(items_dict[key], str):
-            # keep as is
+
+        if isinstance(item, PILImage.Image):
+            # PIL image (h, w, c) uint8 -> float32 tensor (c, h, w) in [0, 1]
+            items_dict[key] = transforms.ToTensor()(item)
+        elif isinstance(item, (str, bytes)):
+            # Let strings and bytes (like 'task') pass through untouched
             pass
         else:
-            # This handles tensors, ints, floats, etc.
-            items_dict[key] = torch.tensor(items_dict[key])
+            # Convert all other numeric types (int, float, list, np.ndarray) to tensors
+            try:
+                items_dict[key] = torch.tensor(item)
+            except Exception:
+                # Report the offending key and value before re-raising
+                print(f"Error converting item['{key}'] to tensor. Value: {item}, Type: {type(item)}")
+                raise
     return items_dict
 
 
@@ -857,15 +864,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         features = get_hf_features_from_features(self.features)
 
-        # 1. Check if specific episodes are requested by the user.
-        # This is the "data_files" path, which may be slow, but is
-        # necessary for visualization or evaluation on a subset.
         if self.episodes is not None:
-            # Get the unique set of parquet files for the requested episodes
+            # Path for episode-specific loading (e.g., visualization)
             fpaths = set()
             for ep_idx in self.episodes:
-                # Need to read metadata to find the file path for this episode
-                # Use the pre-loaded metadata list
                 ep_meta = self.episodes_metadata_list[ep_idx]
                 chunk_idx = ep_meta["data/chunk_index"]
                 file_idx = ep_meta["data/file_index"]
@@ -878,8 +880,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "parquet", data_files=data_files, features=features, split="train"
             )
-            # Filter the loaded dataset to *only* include the requested episodes
-            # This is necessary because v3 files can contain multiple episodes.
             requested_episodes_set = set(self.episodes)
             hf_dataset = hf_dataset.filter(
                 lambda x: x["episode_index"] in requested_episodes_set, batched=True, batch_size=1000
             )
@@ -887,7 +887,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         else:
             # THIS IS THE FAST PATH FOR TRAINING (self.episodes is None)
-            # We must use `data_dir` to trigger the v2.1-style efficient cache.
+            # Use `data_dir` to trigger the v2.1-style efficient cache.
             data_dir = str(self.root / "data")
             hf_dataset = datasets.load_dataset("parquet", data_dir=data_dir, features=features, split="train")
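
As a quick sanity check of the item-level behavior, here is a minimal sketch (not part of the patch; the field names and values are made up, not a fixed LeRobot schema) that exercises `hf_transform_to_torch` directly on a hand-built row:

```python
# Minimal sketch, assuming the patch above is applied.
import torch
from PIL import Image as PILImage

from lerobot.datasets.lerobot_dataset import hf_transform_to_torch

row = {
    "observation.state": [0.1, 0.2, 0.3],                   # list[float] -> float32 tensor
    "episode_index": 3,                                     # int -> 0-dim int64 tensor
    "task": "pick up the cube",                             # str passes through unchanged
    "observation.images.top": PILImage.new("RGB", (4, 4)),  # PIL -> (3, 4, 4) float32 in [0, 1]
}

out = hf_transform_to_torch(row)
assert out["observation.state"].dtype == torch.float32
assert out["episode_index"].dtype == torch.int64
assert out["task"] == "pick up the cube"
assert out["observation.images.top"].shape == (3, 4, 4)
```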
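For reviewers who want to try the two loading paths outside the class, the difference boils down to how `datasets.load_dataset` is invoked. A hedged sketch follows: the paths, file layout, and episode indices are invented (the real file paths come from the episode metadata, which the hunks above don't show), and a non-batched filter is used for brevity:

```python
# Sketch contrasting the two call styles from the diff; paths are hypothetical.
import datasets

# Fast path (training): hand the whole data directory to `datasets`, which,
# per the comment in the diff, triggers the v2.1-style efficient cache.
full = datasets.load_dataset("parquet", data_dir="my_dataset/data", split="train")

# Subset path (visualization/evaluation): list explicit parquet files, then
# filter, because a v3 file can contain several episodes.
subset = datasets.load_dataset(
    "parquet",
    data_files=["my_dataset/data/chunk-000/file-000.parquet"],  # hypothetical layout
    split="train",
)
wanted = {0, 2}
subset = subset.filter(lambda x: x["episode_index"] in wanted)
```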