refactor: support custom progress parquet overlays (#3640)

This commit is contained in:
Khalil Meftah
2026-05-21 14:32:10 +02:00
committed by GitHub
parent f4b834844e
commit bac4f61eae
+37 -15
View File
@@ -15,10 +15,12 @@
# limitations under the License. # limitations under the License.
""" """
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes. Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result. of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage: Usage:
python examples/dataset/create_progress_videos.py \ python examples/dataset/create_progress_videos.py \
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55 TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path: def download_episode_metadata(
"""Download only the metadata and sarm_progress files for a dataset. repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
Args: Args:
repo_id: HuggingFace dataset repository ID. repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched). episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns: Returns:
Local cache path for the downloaded snapshot. Local cache path for the downloaded snapshot.
""" """
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode) logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
local_path = Path( local_path = Path(
snapshot_download( snapshot_download(
repo_id=repo_id, repo_id=repo_id,
repo_type="dataset", repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"], allow_patterns=["meta/**", progress_file],
ignore_patterns=["*.mp4"], ignore_patterns=["*.mp4"],
) )
) )
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None: def load_progress_data(
"""Load sarm_progress values for an episode. local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
Args: Args:
local_path: Dataset cache root. local_path: Dataset cache root.
episode: Episode index. episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns: Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable. Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
""" """
parquet_path = local_path / "sarm_progress.parquet" parquet_path = local_path / progress_file
if not parquet_path.exists(): if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found") logging.warning("%s not found", progress_file)
return None return None
df = pd.read_parquet(parquet_path) df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns)) logging.info(" %s columns: %s", progress_file, list(df.columns))
episode_df = df[df["episode_index"] == episode].copy() episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty: if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode) logging.warning("No progress rows for episode %d in %s", episode, progress_file)
return None return None
episode_df = episode_df.sort_values("frame_index") episode_df = episode_df.sort_values("frame_index")
@@ -576,6 +585,7 @@ def process_dataset(
camera_key: str | None, camera_key: str | None,
output_dir: Path, output_dir: Path,
create_gif: bool = False, create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None: ) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output. """Full pipeline: download, extract metadata, composite progress, write output.
@@ -585,6 +595,8 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection. camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files. output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4. create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns: Returns:
Path to the final output file, or None on failure. Path to the final output file, or None on failure.
@@ -592,7 +604,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_") safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode) logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode) local_path = download_episode_metadata(repo_id, episode, progress_file)
logging.info(" Local cache: %s", local_path) logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key) episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -600,9 +612,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"]) video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode) progress_data = load_progress_data(local_path, episode, progress_file)
if progress_data is None: if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.") logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
return None return None
logging.info(" Progress frames: %d", len(progress_data)) logging.info(" Progress frames: %d", len(progress_data))
@@ -627,7 +639,7 @@ def process_dataset(
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes." description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
) )
parser.add_argument( parser.add_argument(
"--repo-id", "--repo-id",
@@ -658,6 +670,15 @@ def main() -> None:
action="store_true", action="store_true",
help="Also generate a GIF from the MP4 output.", help="Also generate a GIF from the MP4 output.",
) )
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -670,6 +691,7 @@ def main() -> None:
camera_key=args.camera_key, camera_key=args.camera_key,
output_dir=args.output_dir, output_dir=args.output_dir,
create_gif=args.gif, create_gif=args.gif,
progress_file=args.progress_file,
) )
if result: if result: