mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
fixed tensor indicies in _check_cached_episode_sufficient in lerobot_dataset.py, added test
This commit is contained in:
@@ -36,6 +36,8 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
create_branch,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.envs.factory import make_env_config
|
||||
@@ -552,3 +554,103 @@ def test_create_branch():
|
||||
|
||||
# Clean
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
|
||||
def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
"""Test the _check_cached_episodes_sufficient method of LeRobotDataset."""
|
||||
# Create a dataset with 5 episodes (0-4)
|
||||
dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "test",
|
||||
total_episodes=5,
|
||||
total_frames=200,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
# Test hf_dataset is None
|
||||
dataset.hf_dataset = None
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test hf_dataset is empty
|
||||
import datasets
|
||||
|
||||
empty_features = get_hf_features_from_features(dataset.features)
|
||||
dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
{key: [] for key in empty_features}, features=empty_features
|
||||
)
|
||||
dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Restore the original dataset for remaining tests
|
||||
dataset.hf_dataset = dataset.load_hf_dataset()
|
||||
|
||||
# Test all episodes requested (self.episodes = None) and all are available
|
||||
dataset.episodes = None
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test specific episodes requested that are all available
|
||||
dataset.episodes = [0, 2, 4]
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test request episodes that don't exist in the cached dataset
|
||||
# Create a dataset with only episodes 0, 1, 2
|
||||
limited_dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "limited",
|
||||
total_episodes=3,
|
||||
total_frames=120,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
# Request episodes that include non-existent ones
|
||||
limited_dataset.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
|
||||
# First create the full dataset structure
|
||||
sparse_dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "sparse",
|
||||
total_episodes=5,
|
||||
total_frames=200,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
# Manually filter hf_dataset to only include episodes 0, 2, 4
|
||||
episode_indices = sparse_dataset.hf_dataset["episode_index"]
|
||||
mask = torch.zeros(len(episode_indices), dtype=torch.bool)
|
||||
for ep in [0, 2, 4]:
|
||||
mask |= torch.tensor(episode_indices) == ep
|
||||
|
||||
# Create a filtered dataset
|
||||
filtered_data = {}
|
||||
# Find image keys by checking features
|
||||
image_keys = [key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"]
|
||||
|
||||
for key in sparse_dataset.hf_dataset.column_names:
|
||||
values = sparse_dataset.hf_dataset[key]
|
||||
# Filter values based on mask
|
||||
filtered_values = [val for i, val in enumerate(values) if mask[i]]
|
||||
|
||||
# Convert float32 image tensors back to uint8 numpy arrays for HuggingFace dataset
|
||||
if key in image_keys and len(filtered_values) > 0:
|
||||
# Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC)
|
||||
filtered_values = [
|
||||
(val.permute(1, 2, 0).numpy() * 255).astype(np.uint8) for val in filtered_values
|
||||
]
|
||||
|
||||
filtered_data[key] = filtered_values
|
||||
|
||||
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
|
||||
)
|
||||
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
# Test requesting all episodes when only some are cached
|
||||
sparse_dataset.episodes = None
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test requesting only the available episodes
|
||||
sparse_dataset.episodes = [0, 2, 4]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test requesting a mix of available and unavailable episodes
|
||||
sparse_dataset.episodes = [0, 1, 2]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
Reference in New Issue
Block a user