diff --git a/examples/dataset/create_progress_videos.py b/examples/dataset/create_progress_videos.py index 5f98d2cea..cb85a9d3a 100644 --- a/examples/dataset/create_progress_videos.py +++ b/examples/dataset/create_progress_videos.py @@ -15,10 +15,12 @@ # limitations under the License. """ -Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes. +Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes. Downloads datasets from HuggingFace, seeks directly into the episode segment of the source video, draws a progress line on each frame, and writes the result. +The progress data is read from a parquet file that lives alongside the dataset +(configurable via ``--progress-file``). Usage: python examples/dataset/create_progress_videos.py \ @@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8 TASK_FONT_SCALE = 0.55 -def download_episode_metadata(repo_id: str, episode: int) -> Path: - """Download only the metadata and sarm_progress files for a dataset. +def download_episode_metadata( + repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet" +) -> Path: + """Download only the metadata and per-frame progress file for a dataset. Args: repo_id: HuggingFace dataset repository ID. episode: Episode index (used for logging only; all meta is fetched). + progress_file: Filename of the per-frame progress parquet inside the + dataset repo. Returns: Local cache path for the downloaded snapshot. """ - logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode) + logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode) local_path = Path( snapshot_download( repo_id=repo_id, repo_type="dataset", - allow_patterns=["meta/**", "sarm_progress.parquet"], + allow_patterns=["meta/**", progress_file], ignore_patterns=["*.mp4"], ) ) @@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path: return video_path -def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None: - """Load sarm_progress values for an episode. +def load_progress_data( + local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet" +) -> np.ndarray | None: + """Load per-frame progress values for an episode. Args: local_path: Dataset cache root. episode: Episode index. + progress_file: Filename of the per-frame progress parquet. Returns: Sorted (N, 2) array of (frame_index, progress), or None if unavailable. """ - parquet_path = local_path / "sarm_progress.parquet" + parquet_path = local_path / progress_file if not parquet_path.exists(): - logging.warning("sarm_progress.parquet not found") + logging.warning("%s not found", progress_file) return None df = pd.read_parquet(parquet_path) - logging.info(" sarm_progress.parquet columns: %s", list(df.columns)) + logging.info(" %s columns: %s", progress_file, list(df.columns)) episode_df = df[df["episode_index"] == episode].copy() if episode_df.empty: - logging.warning("No sarm_progress rows for episode %d", episode) + logging.warning("No progress rows for episode %d in %s", episode, progress_file) return None episode_df = episode_df.sort_values("frame_index") @@ -576,6 +585,7 @@ def process_dataset( camera_key: str | None, output_dir: Path, create_gif: bool = False, + progress_file: str = "sarm_progress.parquet", ) -> Path | None: """Full pipeline: download, extract metadata, composite progress, write output. @@ -585,6 +595,8 @@ def process_dataset( camera_key: Camera key to use, or None for auto-selection. output_dir: Directory to write output files. create_gif: If True, also generate a GIF from the MP4. + progress_file: Filename of the per-frame progress parquet inside the + dataset repo. Returns: Path to the final output file, or None on failure. @@ -592,7 +604,7 @@ def process_dataset( safe_name = repo_id.replace("/", "_") logging.info("Processing: %s | episode %d", repo_id, episode) - local_path = download_episode_metadata(repo_id, episode) + local_path = download_episode_metadata(repo_id, episode, progress_file) logging.info(" Local cache: %s", local_path) episode_meta = load_episode_meta(local_path, episode, camera_key) @@ -600,9 +612,9 @@ def process_dataset( video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"]) - progress_data = load_progress_data(local_path, episode) + progress_data = load_progress_data(local_path, episode, progress_file) if progress_data is None: - logging.error("Could not load sarm_progress data. Skipping overlay.") + logging.error("Could not load progress data from %s. Skipping overlay.", progress_file) return None logging.info(" Progress frames: %d", len(progress_data)) @@ -627,7 +639,7 @@ def process_dataset( def main() -> None: parser = argparse.ArgumentParser( - description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes." + description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes." ) parser.add_argument( "--repo-id", @@ -658,6 +670,15 @@ def main() -> None: action="store_true", help="Also generate a GIF from the MP4 output.", ) + parser.add_argument( + "--progress-file", + type=str, + default="sarm_progress.parquet", + help=( + "Filename of the per-frame progress parquet inside the dataset repo " + "(default: 'sarm_progress.parquet')." + ), + ) args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") @@ -670,6 +691,7 @@ def main() -> None: camera_key=args.camera_key, output_dir=args.output_dir, create_gif=args.gif, + progress_file=args.progress_file, ) if result: