diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 83d452a44..8fa4f200b 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # Check if cached dataset contains all requested episodes if not self._check_cached_episodes_sufficient(): raise FileNotFoundError("Cached dataset doesn't contain all requested episodes") - except (AssertionError, FileNotFoundError, NotADirectoryError): + except (FileNotFoundError, NotADirectoryError): if is_valid_version(self.revision): self.revision = get_safe_version(self.repo_id, self.revision) self.download(download_videos) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index acc24a9e0..8c8494b87 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -227,16 +227,17 @@ def decode_video_frames_torchvision( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - f"\nbackend: {backend}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + f"\nbackend: {backend}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) @@ -248,7 +249,11 @@ def decode_video_frames_torchvision( # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 - assert len(timestamps) == len(closest_frames) + if len(timestamps) != len(closest_frames): + raise FrameTimestampError( + f"Number of retrieved frames ({len(closest_frames)}) does not match " + f"number of queried timestamps ({len(timestamps)})" + ) return closest_frames @@ -353,15 +358,16 @@ def decode_video_frames_torchcodec( min_, argmin_ = dist.min(1) is_within_tol = min_ < tolerance_s - assert is_within_tol.all(), ( - f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." - "It means that the closest frame that can be loaded from the video is too far away in time." - "This might be due to synchronization issues with timestamps during data collection." - "To be safe, we advise to ignore this item during training." - f"\nqueried timestamps: {query_ts}" - f"\nloaded timestamps: {loaded_ts}" - f"\nvideo: {video_path}" - ) + if not is_within_tol.all(): + raise FrameTimestampError( + f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." + " It means that the closest frame that can be loaded from the video is too far away in time." + " This might be due to synchronization issues with timestamps during data collection." + " To be safe, we advise to ignore this item during training." + f"\nqueried timestamps: {query_ts}" + f"\nloaded timestamps: {loaded_ts}" + f"\nvideo: {video_path}" + ) # get closest frames to the query timestamps closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])