diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 885b45582..16127056a 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1043,17 +1043,28 @@ class LeRobotDataset(torch.utils.data.Dataset): def __getitem__(self, idx) -> dict: # Ensure dataset is loaded when we actually need to read from it self._ensure_hf_dataset_loaded() + + # Get the single, current-timestep item item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() query_indices = None if self.delta_indices is not None: + # 1. Get indices for all deltas query_indices, padding = self._get_query_indices(idx, ep_idx) + + # 2. Query non-image, non-video features query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): item[key] = val + # 3. Query image features (which are not in _query_hf_dataset) + for key in self.meta.image_keys: + if key in query_indices: + # hf_dataset[query_indices[key]][key] returns a LIST of PIL.Image objects + item[key] = torch.stack(self.hf_dataset[query_indices[key]][key]) + if len(self.meta.video_keys) > 0: current_ts = item["timestamp"].item() query_timestamps = self._get_query_timestamps(current_ts, query_indices)