diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 420117996..6798e7fd7 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -935,17 +935,30 @@ class LeRobotDataset(torch.utils.data.Dataset): else: return get_hf_features_from_features(self.features) - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices( + self, abs_idx: int, ep_idx: int + ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: + """Compute query indices for delta timestamps. + + Args: + abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). + ep_idx: The episode index. + + Returns: + A tuple of (query_indices, padding) where: + - query_indices: Dict mapping keys to lists of absolute indices to query + - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions + """ ep = self.meta.episodes[ep_idx] ep_start = ep["dataset_from_index"] ep_end = ep["dataset_to_index"] query_indices = { - key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] + key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx] + [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] ) for key, delta_idx in self.delta_indices.items() } @@ -1037,10 +1050,12 @@ class LeRobotDataset(torch.utils.data.Dataset): self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() + # Use the absolute index from the dataset for delta timestamp calculations + abs_idx = item["index"].item() query_indices = None if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(idx, ep_idx) + query_indices, padding = self._get_query_indices(abs_idx, ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 157a14851..27c51b3c4 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1451,3 +1451,202 @@ def test_valid_video_codecs_constant(): assert "hevc" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS assert len(VALID_VIDEO_CODECS) == 3 + + +def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Regression test for bug where delta_timestamps incorrectly marked all frames as padded when using episodes filter. + + The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset + instead of the absolute index when comparing against episode boundaries (ep_start, ep_end). + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Create 3 episodes with 10 frames each + frames_per_episode = 10 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 (middle episode) with delta_timestamps + delta_ts = {"observation.state": [0.0]} # Just the current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + ) + + # Verify the filtered dataset has the correct length + assert len(filtered_dataset) == frames_per_episode + + # Check that no frames are marked as padded (since delta=0 should always be valid) + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False, f"Frame {idx} incorrectly marked as padded" + # Verify we're getting data from episode 1 + assert frame["episode_index"].item() == 1 + + +def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 3 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 with delta_timestamps that go beyond episode boundaries + # fps=10, so 0.1s = 1 frame offset + delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, # Slightly less than half a frame at 10fps + ) + + assert len(filtered_dataset) == frames_per_episode + + # Check padding at the start of the episode (first frame) + first_frame = filtered_dataset[0] + is_pad = first_frame["observation.state_is_pad"].tolist() + # At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not + assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}" + + # Check middle frame (no padding expected) + mid_frame = filtered_dataset[2] + is_pad = mid_frame["observation.state_is_pad"].tolist() + assert is_pad == [False, False, False, False, False], f"Middle frame padding incorrect: {is_pad}" + + # Check padding at the end of the episode (last frame) + last_frame = filtered_dataset[4] + is_pad = last_frame["observation.state_is_pad"].tolist() + # At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be + assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}" + + +def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Test delta_timestamps with multiple non-consecutive episodes selected.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 5 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(5): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episodes 1 and 3 (non-consecutive) + delta_ts = {"observation.state": [0.0]} + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1, 3], + delta_timestamps=delta_ts, + ) + + assert len(filtered_dataset) == 2 * frames_per_episode + + # All frames should have valid (non-padded) data for delta=0 + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False + + # Verify we're getting the correct episodes + episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))] + expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode + assert episode_indices == expected_episodes + + +def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps returns the correct observation values, not just correct padding.""" + features = { + "observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 2 episodes with known values + # Episode 0: frames with values 0, 1, 2, 3, 4 + # Episode 1: frames with values 10, 11, 12, 13, 14 + frames_per_episode = 5 + for ep_idx in range(2): + for frame_idx in range(frames_per_episode): + value = ep_idx * 10 + frame_idx + dataset.add_frame( + { + "observation.state": torch.tensor([value], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episode 1 with delta that looks at previous frame + delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, + ) + + # Check frame 2 of episode 1 (which has absolute index 7, value 12) + frame = filtered_dataset[2] + state_values = frame["observation.state"].tolist() + # Should get [11, 12] - the previous and current values within episode 1 + assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}" + + # Check first frame - previous frame should be clamped to episode start (padded) + first_frame = filtered_dataset[0] + state_values = first_frame["observation.state"].tolist() + is_pad = first_frame["observation.state_is_pad"].tolist() + # Previous frame is outside episode, so it's clamped to first frame and marked as padded + assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}" + assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"