mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-11 12:09:41 +00:00
🐛 fix v30→v21 converter imports (#101)
Co-authored-by: Codex <codex@openai.com>
This commit is contained in:
@@ -32,6 +32,11 @@ import pyarrow.parquet as pq
|
|||||||
import tqdm
|
import tqdm
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from lerobot.datasets.io_utils import (
|
||||||
|
load_info,
|
||||||
|
load_tasks,
|
||||||
|
write_info,
|
||||||
|
)
|
||||||
from lerobot.datasets.utils import (
|
from lerobot.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
@@ -40,11 +45,8 @@ from lerobot.datasets.utils import (
|
|||||||
LEGACY_EPISODES_PATH,
|
LEGACY_EPISODES_PATH,
|
||||||
LEGACY_EPISODES_STATS_PATH,
|
LEGACY_EPISODES_STATS_PATH,
|
||||||
LEGACY_TASKS_PATH,
|
LEGACY_TASKS_PATH,
|
||||||
load_info,
|
|
||||||
load_tasks,
|
|
||||||
serialize_dict,
|
serialize_dict,
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
write_info,
|
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.utils.utils import init_logging
|
from lerobot.utils.utils import init_logging
|
||||||
@@ -52,8 +54,12 @@ from lerobot.utils.utils import init_logging
|
|||||||
V21 = "v2.1"
|
V21 = "v2.1"
|
||||||
V30 = "v3.0"
|
V30 = "v3.0"
|
||||||
|
|
||||||
LEGACY_DATA_PATH_TEMPLATE = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
LEGACY_DATA_PATH_TEMPLATE = (
|
||||||
LEGACY_VIDEO_PATH_TEMPLATE = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||||
|
)
|
||||||
|
LEGACY_VIDEO_PATH_TEMPLATE = (
|
||||||
|
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||||
|
)
|
||||||
MIN_VIDEO_DURATION = 1e-6
|
MIN_VIDEO_DURATION = 1e-6
|
||||||
LEGACY_STATS_KEYS = ("mean", "std", "min", "max", "count")
|
LEGACY_STATS_KEYS = ("mean", "std", "min", "max", "count")
|
||||||
|
|
||||||
@@ -137,7 +143,9 @@ def convert_info(
|
|||||||
if ft.get("dtype") != "video":
|
if ft.get("dtype") != "video":
|
||||||
ft.pop("fps", None)
|
ft.pop("fps", None)
|
||||||
|
|
||||||
info["total_chunks"] = math.ceil(total_episodes / chunks_size) if total_episodes > 0 else 0
|
info["total_chunks"] = (
|
||||||
|
math.ceil(total_episodes / chunks_size) if total_episodes > 0 else 0
|
||||||
|
)
|
||||||
info["total_videos"] = total_episodes * len(video_keys)
|
info["total_videos"] = total_episodes * len(video_keys)
|
||||||
|
|
||||||
write_info(info, new_root)
|
write_info(info, new_root)
|
||||||
@@ -156,14 +164,22 @@ def _group_episodes_by_data_file(
|
|||||||
return grouped
|
return grouped
|
||||||
|
|
||||||
|
|
||||||
def convert_data(root: Path, new_root: Path, episode_records: list[dict[str, Any]]) -> None:
|
def convert_data(
|
||||||
|
root: Path, new_root: Path, episode_records: list[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
logging.info("Converting consolidated parquet files back to per-episode files")
|
logging.info("Converting consolidated parquet files back to per-episode files")
|
||||||
grouped = _group_episodes_by_data_file(episode_records)
|
grouped = _group_episodes_by_data_file(episode_records)
|
||||||
|
|
||||||
for (chunk_idx, file_idx), records in tqdm.tqdm(grouped.items(), desc="convert data files"):
|
for (chunk_idx, file_idx), records in tqdm.tqdm(
|
||||||
source_path = root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
grouped.items(), desc="convert data files"
|
||||||
|
):
|
||||||
|
source_path = root / DEFAULT_DATA_PATH.format(
|
||||||
|
chunk_index=chunk_idx, file_index=file_idx
|
||||||
|
)
|
||||||
if not source_path.exists():
|
if not source_path.exists():
|
||||||
raise FileNotFoundError(f"Expected source parquet file not found: {source_path}")
|
raise FileNotFoundError(
|
||||||
|
f"Expected source parquet file not found: {source_path}"
|
||||||
|
)
|
||||||
|
|
||||||
table = pq.read_table(source_path)
|
table = pq.read_table(source_path)
|
||||||
records = sorted(records, key=lambda rec: int(rec["dataset_from_index"]))
|
records = sorted(records, key=lambda rec: int(rec["dataset_from_index"]))
|
||||||
@@ -235,10 +251,14 @@ def _validate_video_paths(src: Path, dst: Path) -> None:
|
|||||||
# Validate file extensions for video files
|
# Validate file extensions for video files
|
||||||
valid_video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".m4v"}
|
valid_video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".m4v"}
|
||||||
if src_resolved.suffix.lower() not in valid_video_extensions:
|
if src_resolved.suffix.lower() not in valid_video_extensions:
|
||||||
raise ValueError(f"Source file does not have a valid video extension: {src_resolved}")
|
raise ValueError(
|
||||||
|
f"Source file does not have a valid video extension: {src_resolved}"
|
||||||
|
)
|
||||||
|
|
||||||
if dst_resolved.suffix.lower() not in valid_video_extensions:
|
if dst_resolved.suffix.lower() not in valid_video_extensions:
|
||||||
raise ValueError(f"Destination file does not have a valid video extension: {dst_resolved}")
|
raise ValueError(
|
||||||
|
f"Destination file does not have a valid video extension: {dst_resolved}"
|
||||||
|
)
|
||||||
|
|
||||||
# Check for path traversal attempts in the original paths
|
# Check for path traversal attempts in the original paths
|
||||||
src_str = str(src)
|
src_str = str(src)
|
||||||
@@ -253,11 +273,16 @@ def _validate_video_paths(src: Path, dst: Path) -> None:
|
|||||||
|
|
||||||
# Additional check: ensure resolved paths don't point to system directories
|
# Additional check: ensure resolved paths don't point to system directories
|
||||||
system_dirs = {"/etc", "/sys", "/proc", "/dev", "/boot", "/root"}
|
system_dirs = {"/etc", "/sys", "/proc", "/dev", "/boot", "/root"}
|
||||||
for resolved_path, name in [(src_resolved, "source"), (dst_resolved, "destination")]:
|
for resolved_path, name in [
|
||||||
|
(src_resolved, "source"),
|
||||||
|
(dst_resolved, "destination"),
|
||||||
|
]:
|
||||||
path_str = str(resolved_path)
|
path_str = str(resolved_path)
|
||||||
for sys_dir in system_dirs:
|
for sys_dir in system_dirs:
|
||||||
if path_str.startswith(sys_dir + "/") or path_str == sys_dir:
|
if path_str.startswith(sys_dir + "/") or path_str == sys_dir:
|
||||||
raise ValueError(f"Path points to system directory: {name} path {resolved_path}")
|
raise ValueError(
|
||||||
|
f"Path points to system directory: {name} path {resolved_path}"
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure the destination directory can be created safely
|
# Ensure the destination directory can be created safely
|
||||||
try:
|
try:
|
||||||
@@ -324,9 +349,13 @@ def _extract_video_segment(
|
|||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
except subprocess.TimeoutExpired as exc:
|
except subprocess.TimeoutExpired as exc:
|
||||||
raise RuntimeError(f"ffmpeg timed out while processing video '{src}' -> '{dst}'") from exc
|
raise RuntimeError(
|
||||||
|
f"ffmpeg timed out while processing video '{src}' -> '{dst}'"
|
||||||
|
) from exc
|
||||||
except FileNotFoundError as exc:
|
except FileNotFoundError as exc:
|
||||||
raise RuntimeError("ffmpeg executable not found; it is required for video conversion") from exc
|
raise RuntimeError(
|
||||||
|
"ffmpeg executable not found; it is required for video conversion"
|
||||||
|
) from exc
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
error_msg = f"ffmpeg failed while splitting video '{src}' into '{dst}'"
|
error_msg = f"ffmpeg failed while splitting video '{src}' into '{dst}'"
|
||||||
if exc.stderr:
|
if exc.stderr:
|
||||||
@@ -334,7 +363,12 @@ def _extract_video_segment(
|
|||||||
raise RuntimeError(error_msg) from exc
|
raise RuntimeError(error_msg) from exc
|
||||||
|
|
||||||
|
|
||||||
def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, Any]], video_keys: list[str]) -> None:
|
def convert_videos(
|
||||||
|
root: Path,
|
||||||
|
new_root: Path,
|
||||||
|
episode_records: list[dict[str, Any]],
|
||||||
|
video_keys: list[str],
|
||||||
|
) -> None:
|
||||||
if len(video_keys) == 0:
|
if len(video_keys) == 0:
|
||||||
logging.info("No video features detected; skipping video conversion")
|
logging.info("No video features detected; skipping video conversion")
|
||||||
return
|
return
|
||||||
@@ -347,7 +381,9 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
|||||||
logging.info("No video metadata found for key '%s'; skipping", video_key)
|
logging.info("No video metadata found for key '%s'; skipping", video_key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for (chunk_idx, file_idx), records in tqdm.tqdm(grouped.items(), desc=f"convert videos ({video_key})"):
|
for (chunk_idx, file_idx), records in tqdm.tqdm(
|
||||||
|
grouped.items(), desc=f"convert videos ({video_key})"
|
||||||
|
):
|
||||||
src_path = root / DEFAULT_VIDEO_PATH.format(
|
src_path = root / DEFAULT_VIDEO_PATH.format(
|
||||||
video_key=video_key,
|
video_key=video_key,
|
||||||
chunk_index=chunk_idx,
|
chunk_index=chunk_idx,
|
||||||
@@ -356,7 +392,10 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
|||||||
if not src_path.exists():
|
if not src_path.exists():
|
||||||
raise FileNotFoundError(f"Expected MP4 file not found: {src_path}")
|
raise FileNotFoundError(f"Expected MP4 file not found: {src_path}")
|
||||||
|
|
||||||
records = sorted(records, key=lambda rec: float(rec[f"videos/{video_key}/from_timestamp"]))
|
records = sorted(
|
||||||
|
records,
|
||||||
|
key=lambda rec: float(rec[f"videos/{video_key}/from_timestamp"]),
|
||||||
|
)
|
||||||
|
|
||||||
for record in records:
|
for record in records:
|
||||||
episode_index = int(record["episode_index"])
|
episode_index = int(record["episode_index"])
|
||||||
@@ -373,7 +412,9 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
|||||||
_extract_video_segment(src_path, dest_path, start=start, end=end)
|
_extract_video_segment(src_path, dest_path, start=start, end=end)
|
||||||
|
|
||||||
|
|
||||||
def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, Any]]) -> None:
|
def convert_episodes_metadata(
|
||||||
|
new_root: Path, episode_records: list[dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
logging.info("Reconstructing legacy episodes and episodes_stats JSONL files")
|
logging.info("Reconstructing legacy episodes and episodes_stats JSONL files")
|
||||||
|
|
||||||
episodes_path = new_root / LEGACY_EPISODES_PATH
|
episodes_path = new_root / LEGACY_EPISODES_PATH
|
||||||
@@ -396,7 +437,9 @@ def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, An
|
|||||||
jsonlines.open(episodes_path, mode="w") as episodes_writer,
|
jsonlines.open(episodes_path, mode="w") as episodes_writer,
|
||||||
jsonlines.open(stats_path, mode="w") as stats_writer,
|
jsonlines.open(stats_path, mode="w") as stats_writer,
|
||||||
):
|
):
|
||||||
for record in sorted(episode_records, key=lambda rec: int(rec["episode_index"])):
|
for record in sorted(
|
||||||
|
episode_records, key=lambda rec: int(rec["episode_index"])
|
||||||
|
):
|
||||||
legacy_episode = {
|
legacy_episode = {
|
||||||
key: value
|
key: value
|
||||||
for key, value in record.items()
|
for key, value in record.items()
|
||||||
@@ -407,10 +450,14 @@ def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, An
|
|||||||
and key not in {"dataset_from_index", "dataset_to_index"}
|
and key not in {"dataset_from_index", "dataset_to_index"}
|
||||||
}
|
}
|
||||||
|
|
||||||
serializable_episode = {key: _to_serializable(value) for key, value in legacy_episode.items()}
|
serializable_episode = {
|
||||||
|
key: _to_serializable(value) for key, value in legacy_episode.items()
|
||||||
|
}
|
||||||
episodes_writer.write(serializable_episode)
|
episodes_writer.write(serializable_episode)
|
||||||
|
|
||||||
stats_flat = {key: record[key] for key in record if key.startswith("stats/")}
|
stats_flat = {
|
||||||
|
key: record[key] for key in record if key.startswith("stats/")
|
||||||
|
}
|
||||||
stats_nested = unflatten_dict(stats_flat).get("stats", {})
|
stats_nested = unflatten_dict(stats_flat).get("stats", {})
|
||||||
stats_serialized = serialize_dict(_filter_stats(stats_nested))
|
stats_serialized = serialize_dict(_filter_stats(stats_nested))
|
||||||
stats_writer.write(
|
stats_writer.write(
|
||||||
@@ -453,7 +500,11 @@ def convert_dataset(
|
|||||||
new_root.mkdir(parents=True, exist_ok=True)
|
new_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
episode_records = load_episode_records(root)
|
episode_records = load_episode_records(root)
|
||||||
video_keys = [key for key, ft in load_info(root)["features"].items() if ft.get("dtype") == "video"]
|
video_keys = [
|
||||||
|
key
|
||||||
|
for key, ft in load_info(root)["features"].items()
|
||||||
|
if ft.get("dtype") == "video"
|
||||||
|
]
|
||||||
|
|
||||||
convert_info(root, new_root, episode_records, video_keys)
|
convert_info(root, new_root, episode_records, video_keys)
|
||||||
convert_tasks(root, new_root)
|
convert_tasks(root, new_root)
|
||||||
|
|||||||
Reference in New Issue
Block a user