mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor(datasets): module cleanup (#3169)
This commit is contained in:
@@ -19,11 +19,26 @@ import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import combine_feature_dicts, create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: Dataset) -> dict[str, torch.Tensor]:
|
||||
"""Calculate episode data index for testing. Returns {"from": Tensor, "to": Tensor}."""
|
||||
episode_data_index: dict[str, list[int]] = {"from": [], "to": []}
|
||||
current_episode = None
|
||||
if len(hf_dataset) == 0:
|
||||
return {"from": torch.tensor([]), "to": torch.tensor([])}
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
episode_data_index["from"].append(idx)
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
current_episode = episode_idx
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
return {k: torch.tensor(v) for k, v in episode_data_index.items()}
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
|
||||
Reference in New Issue
Block a user