mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
refactor: support custom progress parquet overlays (#3640)
This commit is contained in:
@@ -15,10 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||||
|
|
||||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||||
of the source video, draws a progress line on each frame, and writes the result.
|
of the source video, draws a progress line on each frame, and writes the result.
|
||||||
|
The progress data is read from a parquet file that lives alongside the dataset
|
||||||
|
(configurable via ``--progress-file``).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python examples/dataset/create_progress_videos.py \
|
python examples/dataset/create_progress_videos.py \
|
||||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
|||||||
TASK_FONT_SCALE = 0.55
|
TASK_FONT_SCALE = 0.55
|
||||||
|
|
||||||
|
|
||||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
def download_episode_metadata(
|
||||||
"""Download only the metadata and sarm_progress files for a dataset.
|
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||||
|
) -> Path:
|
||||||
|
"""Download only the metadata and per-frame progress file for a dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: HuggingFace dataset repository ID.
|
repo_id: HuggingFace dataset repository ID.
|
||||||
episode: Episode index (used for logging only; all meta is fetched).
|
episode: Episode index (used for logging only; all meta is fetched).
|
||||||
|
progress_file: Filename of the per-frame progress parquet inside the
|
||||||
|
dataset repo.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Local cache path for the downloaded snapshot.
|
Local cache path for the downloaded snapshot.
|
||||||
"""
|
"""
|
||||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||||
local_path = Path(
|
local_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
allow_patterns=["meta/**", progress_file],
|
||||||
ignore_patterns=["*.mp4"],
|
ignore_patterns=["*.mp4"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
|||||||
return video_path
|
return video_path
|
||||||
|
|
||||||
|
|
||||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
def load_progress_data(
|
||||||
"""Load sarm_progress values for an episode.
|
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||||
|
) -> np.ndarray | None:
|
||||||
|
"""Load per-frame progress values for an episode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_path: Dataset cache root.
|
local_path: Dataset cache root.
|
||||||
episode: Episode index.
|
episode: Episode index.
|
||||||
|
progress_file: Filename of the per-frame progress parquet.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||||
"""
|
"""
|
||||||
parquet_path = local_path / "sarm_progress.parquet"
|
parquet_path = local_path / progress_file
|
||||||
if not parquet_path.exists():
|
if not parquet_path.exists():
|
||||||
logging.warning("sarm_progress.parquet not found")
|
logging.warning("%s not found", progress_file)
|
||||||
return None
|
return None
|
||||||
df = pd.read_parquet(parquet_path)
|
df = pd.read_parquet(parquet_path)
|
||||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||||
episode_df = df[df["episode_index"] == episode].copy()
|
episode_df = df[df["episode_index"] == episode].copy()
|
||||||
if episode_df.empty:
|
if episode_df.empty:
|
||||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||||
return None
|
return None
|
||||||
episode_df = episode_df.sort_values("frame_index")
|
episode_df = episode_df.sort_values("frame_index")
|
||||||
|
|
||||||
@@ -576,6 +585,7 @@ def process_dataset(
|
|||||||
camera_key: str | None,
|
camera_key: str | None,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
create_gif: bool = False,
|
create_gif: bool = False,
|
||||||
|
progress_file: str = "sarm_progress.parquet",
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||||
|
|
||||||
@@ -585,6 +595,8 @@ def process_dataset(
|
|||||||
camera_key: Camera key to use, or None for auto-selection.
|
camera_key: Camera key to use, or None for auto-selection.
|
||||||
output_dir: Directory to write output files.
|
output_dir: Directory to write output files.
|
||||||
create_gif: If True, also generate a GIF from the MP4.
|
create_gif: If True, also generate a GIF from the MP4.
|
||||||
|
progress_file: Filename of the per-frame progress parquet inside the
|
||||||
|
dataset repo.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the final output file, or None on failure.
|
Path to the final output file, or None on failure.
|
||||||
@@ -592,7 +604,7 @@ def process_dataset(
|
|||||||
safe_name = repo_id.replace("/", "_")
|
safe_name = repo_id.replace("/", "_")
|
||||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||||
|
|
||||||
local_path = download_episode_metadata(repo_id, episode)
|
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||||
logging.info(" Local cache: %s", local_path)
|
logging.info(" Local cache: %s", local_path)
|
||||||
|
|
||||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||||
@@ -600,9 +612,9 @@ def process_dataset(
|
|||||||
|
|
||||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||||
|
|
||||||
progress_data = load_progress_data(local_path, episode)
|
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||||
if progress_data is None:
|
if progress_data is None:
|
||||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logging.info(" Progress frames: %d", len(progress_data))
|
logging.info(" Progress frames: %d", len(progress_data))
|
||||||
@@ -627,7 +639,7 @@ def process_dataset(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
@@ -658,6 +670,15 @@ def main() -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Also generate a GIF from the MP4 output.",
|
help="Also generate a GIF from the MP4 output.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--progress-file",
|
||||||
|
type=str,
|
||||||
|
default="sarm_progress.parquet",
|
||||||
|
help=(
|
||||||
|
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||||
|
"(default: 'sarm_progress.parquet')."
|
||||||
|
),
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
@@ -670,6 +691,7 @@ def main() -> None:
|
|||||||
camera_key=args.camera_key,
|
camera_key=args.camera_key,
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
create_gif=args.gif,
|
create_gif=args.gif,
|
||||||
|
progress_file=args.progress_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|||||||
Reference in New Issue
Block a user