mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Fixed tensor indices in _check_cached_episodes_sufficient in lerobot_dataset.py; added test
This commit is contained in:
@@ -675,7 +675,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Get available episode indices from cached dataset
|
# Get available episode indices from cached dataset
|
||||||
available_episodes = set(self.hf_dataset["episode_index"])
|
available_episodes = {
|
||||||
|
ep_idx.item() if isinstance(ep_idx, torch.Tensor) else ep_idx
|
||||||
|
for ep_idx in self.hf_dataset["episode_index"]
|
||||||
|
}
|
||||||
|
|
||||||
# Determine requested episodes
|
# Determine requested episodes
|
||||||
if self.episodes is None:
|
if self.episodes is None:
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ from lerobot.datasets.lerobot_dataset import (
|
|||||||
)
|
)
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
create_branch,
|
create_branch,
|
||||||
|
get_hf_features_from_features,
|
||||||
|
hf_transform_to_torch,
|
||||||
hw_to_dataset_features,
|
hw_to_dataset_features,
|
||||||
)
|
)
|
||||||
from lerobot.envs.factory import make_env_config
|
from lerobot.envs.factory import make_env_config
|
||||||
@@ -552,3 +554,103 @@ def test_create_branch():
|
|||||||
|
|
||||||
# Clean
|
# Clean
|
||||||
api.delete_repo(repo_id, repo_type=repo_type)
|
api.delete_repo(repo_id, repo_type=repo_type)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
    """Test the _check_cached_episodes_sufficient method of LeRobotDataset.

    Scenarios exercised, in order:
      * ``hf_dataset`` is ``None``;
      * ``hf_dataset`` is empty;
      * every episode is cached and all (or a subset) are requested;
      * more episodes are requested than the cached dataset contains;
      * a sparse cache holding only episodes 0, 2 and 4.
    """
    # Function-scope import (as in the original test), hoisted to the top so
    # that both the "empty" and the "sparse" sections visibly depend on it.
    import datasets

    # Create a dataset with 5 episodes (0-4)
    dataset = lerobot_dataset_factory(
        root=tmp_path / "test",
        total_episodes=5,
        total_frames=200,
        use_videos=False,
    )

    # Test hf_dataset is None
    dataset.hf_dataset = None
    assert dataset._check_cached_episodes_sufficient() is False

    # Test hf_dataset is empty
    empty_features = get_hf_features_from_features(dataset.features)
    dataset.hf_dataset = datasets.Dataset.from_dict(
        {key: [] for key in empty_features}, features=empty_features
    )
    dataset.hf_dataset.set_transform(hf_transform_to_torch)
    assert dataset._check_cached_episodes_sufficient() is False

    # Restore the original dataset for remaining tests
    dataset.hf_dataset = dataset.load_hf_dataset()

    # Test all episodes requested (self.episodes = None) and all are available
    dataset.episodes = None
    assert dataset._check_cached_episodes_sufficient() is True

    # Test specific episodes requested that are all available
    dataset.episodes = [0, 2, 4]
    assert dataset._check_cached_episodes_sufficient() is True

    # Test request episodes that don't exist in the cached dataset
    # Create a dataset with only episodes 0, 1, 2
    limited_dataset = lerobot_dataset_factory(
        root=tmp_path / "limited",
        total_episodes=3,
        total_frames=120,
        use_videos=False,
    )

    # Request episodes that include non-existent ones
    limited_dataset.episodes = [0, 1, 2, 3, 4]
    assert limited_dataset._check_cached_episodes_sufficient() is False

    # Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
    # First create the full dataset structure
    sparse_dataset = lerobot_dataset_factory(
        root=tmp_path / "sparse",
        total_episodes=5,
        total_frames=200,
        use_videos=False,
    )

    # Build the per-row keep mask once. int() accepts both plain ints and
    # 0-d tensors, so this works whether or not the torch transform has been
    # applied to the "episode_index" column.
    cached_episodes = {0, 2, 4}
    keep_row = [int(ep_idx) in cached_episodes for ep_idx in sparse_dataset.hf_dataset["episode_index"]]

    # Find image keys by checking features
    image_keys = {key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"}

    # Create a filtered dataset
    filtered_data = {}
    for key in sparse_dataset.hf_dataset.column_names:
        # Filter values based on mask
        filtered_values = [
            val for val, keep in zip(sparse_dataset.hf_dataset[key], keep_row) if keep
        ]

        if key in image_keys:
            # Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC).
            # Round before casting: astype(np.uint8) truncates, so a value such as
            # 0.99999994 would otherwise become 0 instead of 1.
            filtered_values = [
                (val.permute(1, 2, 0).numpy() * 255).round().astype(np.uint8) for val in filtered_values
            ]

        filtered_data[key] = filtered_values

    sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
        filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
    )
    sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)

    # Test requesting all episodes when only some are cached
    sparse_dataset.episodes = None
    assert sparse_dataset._check_cached_episodes_sufficient() is False

    # Test requesting only the available episodes
    sparse_dataset.episodes = [0, 2, 4]
    assert sparse_dataset._check_cached_episodes_sufficient() is True

    # Test requesting a mix of available and unavailable episodes
    sparse_dataset.episodes = [0, 1, 2]
    assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||||
|
|||||||
Reference in New Issue
Block a user