Compare commits

...

7 Commits

Author SHA1 Message Date
CarolinePascal 34454748f4 fix(datasets) 2026-03-09 15:06:55 +01:00
CarolinePascal 8e5763c5ab fix(datasets) 2026-03-08 20:53:16 +01:00
CarolinePascal 388d4518ba fix(datasets) 2026-03-07 17:57:56 +01:00
CarolinePascal 232dbe4176 fix(datasets) 2026-03-07 17:46:09 +01:00
CarolinePascal 10c2e2fc87 fix(datasets) 2026-03-07 01:14:19 +01:00
CarolinePascal 5e74f06b20 fix(datasets) 2026-03-07 00:24:01 +01:00
CarolinePascal 07931b1101 fix(datasets) 2026-03-07 00:18:57 +01:00
5 changed files with 45 additions and 12 deletions
+4 -4
View File
@@ -486,8 +486,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
timestamps = torch.tensor(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.tensor(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
@@ -667,7 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
query_timestamps[key] = torch.tensor(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
@@ -675,7 +675,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
key: torch.tensor(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
+2 -2
View File
@@ -632,7 +632,7 @@ def cycle(iterable):
iterator = iter(iterable)
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
def create_branch(repo_id, *, branch: str, repo_type: str | None = None, revision: str | None = None) -> None:
"""Create a branch on an existing Hugging Face repo. Delete the branch if it already
exists before creating it.
"""
@@ -644,7 +644,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
if ref in refs:
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
api.create_branch(repo_id, repo_type=repo_type, branch=branch, revision=revision)
def create_lerobot_dataset_card(
@@ -105,6 +105,7 @@ import filecmp
import json
import logging
import math
import re
import shutil
import subprocess
import tempfile
@@ -119,6 +120,7 @@ from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
from safetensors.torch import load_file
from lerobot.datasets.backward_compatibility import CompatibilityError
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
@@ -130,6 +132,7 @@ from lerobot.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_repo_versions,
get_safe_version,
load_json,
unflatten_dict,
@@ -205,7 +208,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
def get_features_from_hf_dataset(
dataset: Dataset, robot_config: RobotConfig | None = None
) -> dict[str, list]:
robot_config = parse_robot_config(robot_config)
robot_config = parse_robot_config(robot_config) if robot_config else None
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
@@ -325,7 +328,19 @@ def move_videos(
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
videos_moved = True # Videos have already been moved
assert len(video_files) == total_episodes * len(video_keys)
expected_count = total_episodes * len(video_keys)
if len(video_files) != expected_count:
print(
f"Warning: expected {expected_count} video files "
f"({total_episodes} episodes x {len(video_keys)} keys), "
f"found {len(video_files)}. Keeping only videos matching existing episodes."
)
episode_pattern = re.compile(r"episode_(\d+)")
valid_episodes = set(range(total_episodes))
video_files = [
f for f in video_files
if (m := episode_pattern.search(f)) and int(m.group(1)) in valid_episodes
]
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
@@ -442,8 +457,16 @@ def convert_dataset(
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
try:
v1 = get_safe_version(repo_id, V16)
except CompatibilityError:
hub_versions = get_repo_versions(repo_id)
v1x_versions = [v for v in hub_versions if v.major == 1]
if not v1x_versions:
raise
v1 = f"v{max(v1x_versions)}"
logging.warning(f"v1.6 not found for {repo_id}, falling back to {v1}")
v1x_dir = local_dir / v1 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)
v20_dir.mkdir(parents=True, exist_ok=True)
@@ -455,7 +478,7 @@ def convert_dataset(
branch = "main"
if test_branch:
branch = test_branch
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset", revision=v1)
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
@@ -564,6 +587,12 @@ def convert_dataset(
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
info = load_json(v20_dir / INFO_PATH)
if "language_instruction" in info.get("features", {}):
del info["features"]["language_instruction"]
write_json(info, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
@@ -677,6 +706,8 @@ def main():
if args.robot is not None:
robot_config = make_robot_config(args.robot)
else:
robot_config = None
del args.robot
@@ -85,7 +85,7 @@ def convert_dataset(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
#hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
@@ -45,6 +45,8 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
if ft["dtype"] in ["image", "video"] and ep_ft_data.ndim == 3:
ep_ft_data = np.expand_dims(ep_ft_data, axis=0)
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim