From 5ec70f704e15057168f33c5231cbf43f850d64d1 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 18 Jul 2025 16:24:16 +0200 Subject: [PATCH] Removed check_timestamps_sync, which is no longer used in the code, and removed the tests in datasets related to check_timestamps_sync. Added the use of `clear_episode_buffer`, which was not used in `save_episode`. Added the creation of the codebase_version tag that was missing in `slurm_upload`. --- .../port_datasets/droid_rlds/slurm_upload.py | 2 + src/lerobot/datasets/lerobot_dataset.py | 21 +-- src/lerobot/datasets/utils.py | 73 --------- tests/datasets/test_delta_timestamps.py | 140 ------------------ 4 files changed, 4 insertions(+), 232 deletions(-) diff --git a/examples/port_datasets/droid_rlds/slurm_upload.py b/examples/port_datasets/droid_rlds/slurm_upload.py index c9d227126..ade1ef874 100644 --- a/examples/port_datasets/droid_rlds/slurm_upload.py +++ b/examples/port_datasets/droid_rlds/slurm_upload.py @@ -90,6 +90,8 @@ class UploadDataset(PipelineStep): ) card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch) + hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset") + def list_files_recursively(directory): base_path = Path(directory) return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()] diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b4c489426..8e0700db3 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -914,26 +914,9 @@ class LeRobotDataset(torch.utils.data.Dataset): # `meta.save_episode` need to be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata) - # TODO(rcadene): remove? 
there is only one episode in the episode buffer, no need for ep_data_index - # ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) - # ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} - # check_timestamps_sync( - # episode_buffer["timestamp"], - # episode_buffer["episode_index"], - # ep_data_index_np, - # self.fps, - # self.tolerance_s, - # ) - - # TODO(rcadene): images are also deleted in clear_episode_buffer - # delete images - img_dir = self.root / "images" - if img_dir.is_dir(): - shutil.rmtree(self.root / "images") - if not episode_data: - # Reset episode buffer - self.episode_buffer = self.create_episode_buffer() + # Reset episode buffer and clean up temporary images + self.clear_episode_buffer() def _save_episode_data(self, episode_buffer: dict) -> dict: """Save episode data to a parquet file and update the Hugging Face dataset of frames data. diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 1d1101f59..1151d212e 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -602,79 +602,6 @@ def create_empty_dataset_info( } -def check_timestamps_sync( - timestamps: np.ndarray, - episode_indices: np.ndarray, - episode_data_index: dict[str, np.ndarray], - fps: int, - tolerance_s: float, - raise_value_error: bool = True, -) -> bool: - """ - This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance - to account for possible numerical error. - - Args: - timestamps (np.ndarray): Array of timestamps in seconds. - episode_indices (np.ndarray): Array indicating the episode index for each timestamp. - episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', - which identifies indices for the end of each episode. - fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. - tolerance_s (float): Allowed deviation from the expected (1/fps) difference. 
- raise_value_error (bool): Whether to raise a ValueError if the check fails. - - Returns: - bool: True if all checked timestamp differences lie within tolerance, False otherwise. - - Raises: - ValueError: If the check fails and `raise_value_error` is True. - """ - if timestamps.shape != episode_indices.shape: - raise ValueError( - "timestamps and episode_indices should have the same shape. " - f"Found {timestamps.shape=} and {episode_indices.shape=}." - ) - - # Consecutive differences - diffs = np.diff(timestamps) - within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s - - # Mask to ignore differences at the boundaries between episodes - mask = np.ones(len(diffs), dtype=bool) - ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode - mask[ignored_diffs] = False - filtered_within_tolerance = within_tolerance[mask] - - # Check if all remaining diffs are within tolerance - if not np.all(filtered_within_tolerance): - # Track original indices before masking - original_indices = np.arange(len(diffs)) - filtered_indices = original_indices[mask] - outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0] - outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] - - outside_tolerances = [] - for idx in outside_tolerance_indices: - entry = { - "timestamps": [timestamps[idx], timestamps[idx + 1]], - "diff": diffs[idx], - "episode_index": episode_indices[idx].item() - if hasattr(episode_indices[idx], "item") - else episode_indices[idx], - } - outside_tolerances.append(entry) - - if raise_value_error: - raise ValueError( - f"""One or several timestamps unexpectedly violate the tolerance inside episode range. - This might be due to synchronization issues during data collection. 
- \n{pformat(outside_tolerances)}""" - ) - return False - - return True - - def check_delta_timestamps( delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True ) -> bool: diff --git a/tests/datasets/test_delta_timestamps.py b/tests/datasets/test_delta_timestamps.py index 786b90ce2..72f69bc72 100644 --- a/tests/datasets/test_delta_timestamps.py +++ b/tests/datasets/test_delta_timestamps.py @@ -11,83 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from itertools import accumulate - -import datasets -import numpy as np -import pyarrow.compute as pc import pytest -import torch from lerobot.datasets.utils import ( check_delta_timestamps, - check_timestamps_sync, get_delta_indices, ) from tests.fixtures.constants import DUMMY_MOTOR_FEATURES -def calculate_total_episode( - hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True -) -> dict[str, torch.Tensor]: - episode_indices = sorted(hf_dataset.unique("episode_index")) - total_episodes = len(episode_indices) - if raise_if_not_contiguous and episode_indices != list(range(total_episodes)): - raise ValueError("episode_index values are not sorted and contiguous.") - return total_episodes - - -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]: - episode_lengths = [] - table = hf_dataset.data.table - total_episodes = calculate_total_episode(hf_dataset) - for ep_idx in range(total_episodes): - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - episode_lengths.insert(ep_idx, len(ep_table)) - - cumulative_lengths = list(accumulate(episode_lengths)) - return { - "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64), - "to": np.array(cumulative_lengths, dtype=np.int64), - } - - -@pytest.fixture(scope="module") -def synced_timestamps_factory(hf_dataset_factory): - def 
_create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - hf_dataset = hf_dataset_factory(fps=fps) - timestamps = torch.stack(hf_dataset["timestamp"]).numpy() - episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() - episode_data_index = calculate_episode_data_index(hf_dataset) - return timestamps, episode_indices, episode_data_index - - return _create_synced_timestamps - - -@pytest.fixture(scope="module") -def unsynced_timestamps_factory(synced_timestamps_factory): - def _create_unsynced_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_unsynced_timestamps - - -@pytest.fixture(scope="module") -def slightly_off_timestamps_factory(synced_timestamps_factory): - def _create_slightly_off_timestamps( - fps: int = 30, tolerance_s: float = 1e-4 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) - timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance - return timestamps, episode_indices, episode_data_index - - return _create_slightly_off_timestamps - - @pytest.fixture(scope="module") def valid_delta_timestamps_factory(): def _create_valid_delta_timestamps( @@ -136,78 +68,6 @@ def delta_indices_factory(): return _delta_indices -def test_check_timestamps_sync_synced(synced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def 
test_check_timestamps_sync_unsynced(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - with pytest.raises(ValueError): - check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - - -def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - raise_value_error=False, - ) - assert result is False - - -def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=ep_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -def test_check_timestamps_sync_single_timestamp(): - fps = 30 - tolerance_s = 1e-4 - timestamps, ep_idx = np.array([0.0]), np.array([0]) - episode_data_index = {"to": np.array([1]), "from": np.array([0])} - result = check_timestamps_sync( - timestamps=timestamps, - episode_indices=ep_idx, - episode_data_index=episode_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - def test_check_delta_timestamps_valid(valid_delta_timestamps_factory): fps = 30 tolerance_s = 1e-4