From ad1ad11eaca61ef2fea3f98dca68756abb38b99e Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 23 Apr 2025 11:42:21 +0200 Subject: [PATCH] fix hf_dataset.set_transform(hf_transform_to_torch) --- tests/datasets/test_delta_timestamps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index c562f64c9..786b90ce2 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -56,8 +56,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n def synced_timestamps_factory(hf_dataset_factory): def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: hf_dataset = hf_dataset_factory(fps=fps) - timestamps = hf_dataset["timestamp"].numpy() - episode_indices = hf_dataset["episode_index"].numpy() + timestamps = torch.stack(hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() episode_data_index = calculate_episode_data_index(hf_dataset) return timestamps, episode_indices, episode_data_index