mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
removed check_timestamps_sync that is no longer used in the code,
removed tests in datasets related to check_timestamps_sync added the use of `clear_episode_buffer` that was not used in `save_episode` added the creation of the codebase_version tag that was missing in `slurm_upload`
This commit is contained in:
@@ -90,6 +90,8 @@ class UploadDataset(PipelineStep):
|
|||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
|
card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
|
||||||
|
|
||||||
|
hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
|
||||||
|
|
||||||
def list_files_recursively(directory):
|
def list_files_recursively(directory):
|
||||||
base_path = Path(directory)
|
base_path = Path(directory)
|
||||||
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
|
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
|
||||||
|
|||||||
@@ -914,26 +914,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
# `meta.save_episode` need to be executed after encoding the videos
|
# `meta.save_episode` need to be executed after encoding the videos
|
||||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||||
|
|
||||||
# TODO(rcadene): remove? there is only one episode in the episode buffer, no need for ep_data_index
|
|
||||||
# ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
|
||||||
# ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
|
||||||
# check_timestamps_sync(
|
|
||||||
# episode_buffer["timestamp"],
|
|
||||||
# episode_buffer["episode_index"],
|
|
||||||
# ep_data_index_np,
|
|
||||||
# self.fps,
|
|
||||||
# self.tolerance_s,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# TODO(rcadene): images are also deleted in clear_episode_buffer
|
|
||||||
# delete images
|
|
||||||
img_dir = self.root / "images"
|
|
||||||
if img_dir.is_dir():
|
|
||||||
shutil.rmtree(self.root / "images")
|
|
||||||
|
|
||||||
if not episode_data:
|
if not episode_data:
|
||||||
# Reset episode buffer
|
# Reset episode buffer and clean up temporary images
|
||||||
self.episode_buffer = self.create_episode_buffer()
|
self.clear_episode_buffer()
|
||||||
|
|
||||||
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
def _save_episode_data(self, episode_buffer: dict) -> dict:
|
||||||
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
|
"""Save episode data to a parquet file and update the Hugging Face dataset of frames data.
|
||||||
|
|||||||
@@ -602,79 +602,6 @@ def create_empty_dataset_info(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def check_timestamps_sync(
|
|
||||||
timestamps: np.ndarray,
|
|
||||||
episode_indices: np.ndarray,
|
|
||||||
episode_data_index: dict[str, np.ndarray],
|
|
||||||
fps: int,
|
|
||||||
tolerance_s: float,
|
|
||||||
raise_value_error: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
|
|
||||||
to account for possible numerical error.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timestamps (np.ndarray): Array of timestamps in seconds.
|
|
||||||
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
|
|
||||||
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
|
|
||||||
which identifies indices for the end of each episode.
|
|
||||||
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
|
|
||||||
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
|
|
||||||
raise_value_error (bool): Whether to raise a ValueError if the check fails.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the check fails and `raise_value_error` is True.
|
|
||||||
"""
|
|
||||||
if timestamps.shape != episode_indices.shape:
|
|
||||||
raise ValueError(
|
|
||||||
"timestamps and episode_indices should have the same shape. "
|
|
||||||
f"Found {timestamps.shape=} and {episode_indices.shape=}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Consecutive differences
|
|
||||||
diffs = np.diff(timestamps)
|
|
||||||
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
|
|
||||||
|
|
||||||
# Mask to ignore differences at the boundaries between episodes
|
|
||||||
mask = np.ones(len(diffs), dtype=bool)
|
|
||||||
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
|
|
||||||
mask[ignored_diffs] = False
|
|
||||||
filtered_within_tolerance = within_tolerance[mask]
|
|
||||||
|
|
||||||
# Check if all remaining diffs are within tolerance
|
|
||||||
if not np.all(filtered_within_tolerance):
|
|
||||||
# Track original indices before masking
|
|
||||||
original_indices = np.arange(len(diffs))
|
|
||||||
filtered_indices = original_indices[mask]
|
|
||||||
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
|
|
||||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
|
||||||
|
|
||||||
outside_tolerances = []
|
|
||||||
for idx in outside_tolerance_indices:
|
|
||||||
entry = {
|
|
||||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
|
||||||
"diff": diffs[idx],
|
|
||||||
"episode_index": episode_indices[idx].item()
|
|
||||||
if hasattr(episode_indices[idx], "item")
|
|
||||||
else episode_indices[idx],
|
|
||||||
}
|
|
||||||
outside_tolerances.append(entry)
|
|
||||||
|
|
||||||
if raise_value_error:
|
|
||||||
raise ValueError(
|
|
||||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
|
||||||
This might be due to synchronization issues during data collection.
|
|
||||||
\n{pformat(outside_tolerances)}"""
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def check_delta_timestamps(
|
def check_delta_timestamps(
|
||||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|||||||
@@ -11,83 +11,15 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from itertools import accumulate
|
|
||||||
|
|
||||||
import datasets
|
|
||||||
import numpy as np
|
|
||||||
import pyarrow.compute as pc
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
check_timestamps_sync,
|
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
)
|
)
|
||||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
||||||
|
|
||||||
|
|
||||||
def calculate_total_episode(
|
|
||||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
|
||||||
) -> dict[str, torch.Tensor]:
|
|
||||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
|
||||||
total_episodes = len(episode_indices)
|
|
||||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
|
||||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
|
||||||
return total_episodes
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
|
|
||||||
episode_lengths = []
|
|
||||||
table = hf_dataset.data.table
|
|
||||||
total_episodes = calculate_total_episode(hf_dataset)
|
|
||||||
for ep_idx in range(total_episodes):
|
|
||||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
|
||||||
episode_lengths.insert(ep_idx, len(ep_table))
|
|
||||||
|
|
||||||
cumulative_lengths = list(accumulate(episode_lengths))
|
|
||||||
return {
|
|
||||||
"from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
|
|
||||||
"to": np.array(cumulative_lengths, dtype=np.int64),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def synced_timestamps_factory(hf_dataset_factory):
|
|
||||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
hf_dataset = hf_dataset_factory(fps=fps)
|
|
||||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
|
||||||
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
|
|
||||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_synced_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
|
||||||
def _create_unsynced_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
|
||||||
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_unsynced_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
|
||||||
def _create_slightly_off_timestamps(
|
|
||||||
fps: int = 30, tolerance_s: float = 1e-4
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
|
||||||
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
|
|
||||||
return timestamps, episode_indices, episode_data_index
|
|
||||||
|
|
||||||
return _create_slightly_off_timestamps
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def valid_delta_timestamps_factory():
|
def valid_delta_timestamps_factory():
|
||||||
def _create_valid_delta_timestamps(
|
def _create_valid_delta_timestamps(
|
||||||
@@ -136,78 +68,6 @@ def delta_indices_factory():
|
|||||||
return _delta_indices
|
return _delta_indices
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_synced(synced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
raise_value_error=False,
|
|
||||||
)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=ep_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_timestamps_sync_single_timestamp():
|
|
||||||
fps = 30
|
|
||||||
tolerance_s = 1e-4
|
|
||||||
timestamps, ep_idx = np.array([0.0]), np.array([0])
|
|
||||||
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
|
|
||||||
result = check_timestamps_sync(
|
|
||||||
timestamps=timestamps,
|
|
||||||
episode_indices=ep_idx,
|
|
||||||
episode_data_index=episode_data_index,
|
|
||||||
fps=fps,
|
|
||||||
tolerance_s=tolerance_s,
|
|
||||||
)
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||||
fps = 30
|
fps = 30
|
||||||
tolerance_s = 1e-4
|
tolerance_s = 1e-4
|
||||||
|
|||||||
Reference in New Issue
Block a user