From c7a3b0262569e2d235428e69177443eeeba84044 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 13 Aug 2025 16:16:32 +0200 Subject: [PATCH] fixed tensor indicies in `_check_cached_episode_sufficient` in lerobot_dataset.py, added test --- src/lerobot/datasets/lerobot_dataset.py | 5 +- tests/datasets/test_datasets.py | 102 ++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 573b01c00..eb81fa531 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -675,7 +675,10 @@ class LeRobotDataset(torch.utils.data.Dataset): return False # Get available episode indices from cached dataset - available_episodes = set(self.hf_dataset["episode_index"]) + available_episodes = { + ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx + for ep_idx in self.hf_dataset["episode_index"] + } # Determine requested episodes if self.episodes is None: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index c798c6a2a..eb740c972 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -36,6 +36,8 @@ from lerobot.datasets.lerobot_dataset import ( ) from lerobot.datasets.utils import ( create_branch, + get_hf_features_from_features, + hf_transform_to_torch, hw_to_dataset_features, ) from lerobot.envs.factory import make_env_config @@ -552,3 +554,103 @@ def test_create_branch(): # Clean api.delete_repo(repo_id, repo_type=repo_type) + + +def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory): + """Test the _check_cached_episodes_sufficient method of LeRobotDataset.""" + # Create a dataset with 5 episodes (0-4) + dataset = lerobot_dataset_factory( + root=tmp_path / "test", + total_episodes=5, + total_frames=200, + use_videos=False, + ) + + # Test hf_dataset is None + dataset.hf_dataset = None + assert dataset._check_cached_episodes_sufficient() is False + + # Test hf_dataset is empty + import datasets + + empty_features = get_hf_features_from_features(dataset.features) + dataset.hf_dataset = datasets.Dataset.from_dict( + {key: [] for key in empty_features}, features=empty_features + ) + dataset.hf_dataset.set_transform(hf_transform_to_torch) + assert dataset._check_cached_episodes_sufficient() is False + + # Restore the original dataset for remaining tests + dataset.hf_dataset = dataset.load_hf_dataset() + + # Test all episodes requested (self.episodes = None) and all are available + dataset.episodes = None + assert dataset._check_cached_episodes_sufficient() is True + + # Test specific episodes requested that are all available + dataset.episodes = [0, 2, 4] + assert dataset._check_cached_episodes_sufficient() is True + + # Test request episodes that don't exist in the cached dataset + # Create a dataset with only episodes 0, 1, 2 + limited_dataset = lerobot_dataset_factory( + root=tmp_path / "limited", + total_episodes=3, + total_frames=120, + use_videos=False, + ) + + # Request episodes that include non-existent ones + limited_dataset.episodes = [0, 1, 2, 3, 4] + assert limited_dataset._check_cached_episodes_sufficient() is False + + # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4) + # First create the full dataset structure + sparse_dataset = lerobot_dataset_factory( + root=tmp_path / "sparse", + total_episodes=5, + total_frames=200, + use_videos=False, + ) + + # Manually filter hf_dataset to only include episodes 0, 2, 4 + episode_indices = sparse_dataset.hf_dataset["episode_index"] + mask = torch.zeros(len(episode_indices), dtype=torch.bool) + for ep in [0, 2, 4]: + mask |= torch.tensor(episode_indices) == ep + + # Create a filtered dataset + filtered_data = {} + # Find image keys by checking features + image_keys = [key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"] + + for key in sparse_dataset.hf_dataset.column_names: + values = sparse_dataset.hf_dataset[key] + # Filter values based on mask + filtered_values = [val for i, val in enumerate(values) if mask[i]] + + # Convert float32 image tensors back to uint8 numpy arrays for HuggingFace dataset + if key in image_keys and len(filtered_values) > 0: + # Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC) + filtered_values = [ + (val.permute(1, 2, 0).numpy() * 255).astype(np.uint8) for val in filtered_values + ] + + filtered_data[key] = filtered_values + + sparse_dataset.hf_dataset = datasets.Dataset.from_dict( + filtered_data, features=get_hf_features_from_features(sparse_dataset.features) + ) + sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch) + + # Test requesting all episodes when only some are cached + sparse_dataset.episodes = None + assert sparse_dataset._check_cached_episodes_sufficient() is False + + # Test requesting only the available episodes + sparse_dataset.episodes = [0, 2, 4] + assert sparse_dataset._check_cached_episodes_sufficient() is True + + # Test requesting a mix of available and unavailable episodes + sparse_dataset.episodes = [0, 1, 2] + assert sparse_dataset._check_cached_episodes_sufficient() is False