Add create reward visualization (#3155)
* Add create reward visualization and multimodal analysis tool

* add example for creating progress video for sarm

* nit

* precommit

* refactor: address review comments on create_progress_videos.py

  - Add shebang and Apache 2.0 license header
  - Replace hardcoded absolute OUTPUT_DIR with relative default (./progress_videos)
  - Add argparse CLI (--repo-id, --episode, --camera-key, --output-dir, --gif)
  - Wrap entrypoint in def main()
  - Replace all print() with logging
  - Use logging.error/warning instead of traceback.print_exc
  - Release VideoCapture via try/finally; consolidate triple-open into single seek
  - Eliminate intermediate clip file: seek directly via CAP_PROP_POS_MSEC
  - Make MP4 the default output, GIF opt-in via --gif flag
  - Add return types to all functions
  - Add Args/Returns docstrings
  - Use descriptive variable names throughout

  Made-with: Cursor

* refactor: move create_progress_videos.py to examples/dataset/ for consistency

  Made-with: Cursor

* refactor: address PR review comments on create_progress_videos.py

  - Replace Unicode ellipsis and multiplication sign with ASCII equivalents
  - Fix step numbering from 1-5 to 1-4 (only 4 actual steps)
  - Move frame_width reading into convert_mp4_to_gif
  - Remove unused text_height variable

  Made-with: Cursor
@@ -0,0 +1,680 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.

Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result.

Usage:
    python examples/dataset/create_progress_videos.py \
        --repo-id lerobot-data-collection/level2_final_quality3 \
        --episode 1100

    python examples/dataset/create_progress_videos.py \
        --repo-id lerobot-data-collection/level2_final_quality3 \
        --episode 1100 \
        --camera-key observation.images.top \
        --output-dir ./my_videos \
        --gif
"""

from __future__ import annotations

import argparse
import json
import logging
import subprocess
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download

GRAPH_Y_TOP_FRAC = 0.01
GRAPH_Y_BOT_FRAC = 0.99
LINE_THICKNESS = 3
SHADOW_THICKNESS = 6
REF_ALPHA = 0.45
FILL_ALPHA = 0.55
SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55
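

# --- Editor's sketch (not part of the original script) -------------------------
# The constants above place the progress graph as fractions of the frame height.
# A quick worked example, assuming a hypothetical 640x480 frame: the graph band
# spans nearly the whole image, from 1% below the top edge to 99% down.
def _example_overlay_geometry() -> None:
    """Sketch: overlay band in pixel rows for a hypothetical 480 px tall frame."""
    frame_height = 480  # hypothetical; real values come from the source video
    y_top = int(frame_height * GRAPH_Y_TOP_FRAC)  # 4
    y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)  # 475
    assert (y_top, y_bot) == (4, 475)  # progress 1.0 draws near row 4, 0.0 near row 475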


def download_episode_metadata(repo_id: str, episode: int) -> Path:
    """Download only the metadata and sarm_progress files for a dataset.

    Args:
        repo_id: HuggingFace dataset repository ID.
        episode: Episode index (used for logging only; all meta is fetched).

    Returns:
        Local cache path for the downloaded snapshot.
    """
    logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
    local_path = Path(
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            allow_patterns=["meta/**", "sarm_progress.parquet"],
            ignore_patterns=["*.mp4"],
        )
    )
    return local_path


def load_episode_meta(local_path: Path, episode: int, camera_key: str | None) -> dict:
    """Read info.json and episode parquet to resolve fps, video path, and timestamps.

    Args:
        local_path: Local cache directory containing meta/.
        episode: Episode index to look up.
        camera_key: Camera observation key (e.g. "observation.images.base").
            If None, the first available video key is used.

    Returns:
        Dict with keys: fps, camera, video_rel, chunk_index, file_index,
        from_ts, to_ts, task_name.
    """
    info = json.loads((local_path / "meta" / "info.json").read_text())
    fps = info["fps"]
    features = info["features"]

    video_keys = [k for k, v in features.items() if v.get("dtype") == "video"]
    if not video_keys:
        raise RuntimeError("No video keys found in dataset features")

    if camera_key is not None:
        if camera_key not in video_keys:
            raise RuntimeError(f"camera_key='{camera_key}' not found. Available: {video_keys}")
        selected_camera = camera_key
    else:
        selected_camera = video_keys[0]
    logging.info(" fps=%d camera='%s' all_cams=%s", fps, selected_camera, video_keys)

    episode_rows = []
    for parquet_file in sorted((local_path / "meta" / "episodes").glob("**/*.parquet")):
        episode_rows.append(pd.read_parquet(parquet_file))
    episode_df = pd.concat(episode_rows, ignore_index=True)
    row = episode_df[episode_df["episode_index"] == episode]
    if row.empty:
        raise RuntimeError(f"Episode {episode} not found in episode metadata")
    row = row.iloc[0]

    chunk_col = f"videos/{selected_camera}/chunk_index"
    file_col = f"videos/{selected_camera}/file_index"
    ts_from_col = f"videos/{selected_camera}/from_timestamp"
    ts_to_col = f"videos/{selected_camera}/to_timestamp"

    if chunk_col not in row.index:
        chunk_col = f"{selected_camera}/chunk_index"
        file_col = f"{selected_camera}/file_index"
        ts_from_col = f"{selected_camera}/from_timestamp"
        ts_to_col = f"{selected_camera}/to_timestamp"
    if chunk_col not in row.index:
        raise RuntimeError(
            f"Cannot find video metadata columns for {selected_camera}.\nAvailable: {list(row.index)}"
        )

    chunk_index = int(row[chunk_col])
    file_index = int(row[file_col])
    from_timestamp = float(row[ts_from_col])
    to_timestamp = float(row[ts_to_col])

    video_template = info.get(
        "video_path", "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4"
    )
    video_rel = video_template.format(
        video_key=selected_camera,
        chunk_index=chunk_index,
        file_index=file_index,
    )

    task_name = _resolve_task_name(row, local_path)

    return {
        "fps": fps,
        "camera": selected_camera,
        "video_rel": video_rel,
        "chunk_index": chunk_index,
        "file_index": file_index,
        "from_ts": from_timestamp,
        "to_ts": to_timestamp,
        "task_name": task_name,
    }
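

# --- Editor's sketch (illustrative only) ----------------------------------------
# load_episode_meta() above relies on meta/info.json exposing "fps", "features"
# (where camera streams have dtype "video"), and optionally "video_path". A
# hypothetical minimal excerpt, with made-up values, for orientation:
_EXAMPLE_INFO_JSON = {
    "fps": 30,
    "video_path": "videos/{video_key}/chunk-{chunk_index:03d}/file-{file_index:03d}.mp4",
    "features": {
        "observation.images.top": {"dtype": "video"},  # candidate camera key
        "observation.state": {"dtype": "float32"},  # non-video feature: ignored here
    },
}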


def _resolve_task_name(row: pd.Series, local_path: Path) -> str:
    """Best-effort extraction of the task name for an episode row.

    Args:
        row: Single-episode row from the episodes parquet.
        local_path: Dataset cache root.

    Returns:
        Task name string, or empty string if unavailable.
    """
    try:
        if "tasks" in row.index and row["tasks"] is not None:
            tasks_val = row["tasks"]
            if isinstance(tasks_val, (list, tuple, np.ndarray)) and len(tasks_val) > 0:
                return str(tasks_val[0])
            return str(tasks_val).strip("[]'")

        tasks_parquet = local_path / "meta" / "tasks.parquet"
        if tasks_parquet.exists():
            tasks_df = pd.read_parquet(tasks_parquet)
            task_idx = int(row.get("task_index", 0)) if "task_index" in row.index else 0
            match = tasks_df[tasks_df["task_index"] == task_idx]
            if not match.empty:
                return str(match.index[0])
    except Exception as exc:
        logging.warning("Could not load task name: %s", exc)
    return ""


def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
    """Download the specific video file if not already cached.

    Args:
        repo_id: HuggingFace dataset repository ID.
        local_path: Local cache directory.
        video_rel: Relative path to the video file within the dataset.

    Returns:
        Absolute path to the downloaded video file.
    """
    video_path = local_path / video_rel
    if video_path.exists():
        logging.info(" Video already cached: %s", video_path)
        return video_path
    logging.info("[2/4] Downloading video file %s ...", video_rel)
    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=str(local_path),
        allow_patterns=[video_rel],
    )
    if not video_path.exists():
        raise RuntimeError(f"Video not found after download: {video_path}")
    return video_path


def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
    """Load sarm_progress values for an episode.

    Args:
        local_path: Dataset cache root.
        episode: Episode index.

    Returns:
        Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
    """
    parquet_path = local_path / "sarm_progress.parquet"
    if not parquet_path.exists():
        logging.warning("sarm_progress.parquet not found")
        return None
    df = pd.read_parquet(parquet_path)
    logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
    episode_df = df[df["episode_index"] == episode].copy()
    if episode_df.empty:
        logging.warning("No sarm_progress rows for episode %d", episode)
        return None
    episode_df = episode_df.sort_values("frame_index")

    if "progress_dense" in episode_df.columns and episode_df["progress_dense"].notna().any():
        progress_column = "progress_dense"
    elif "progress_sparse" in episode_df.columns:
        progress_column = "progress_sparse"
    else:
        progress_columns = [c for c in episode_df.columns if "progress" in c.lower()]
        if not progress_columns:
            return None
        progress_column = progress_columns[0]

    logging.info(" Using progress column: '%s'", progress_column)
    return episode_df[["frame_index", progress_column]].rename(columns={progress_column: "progress"}).values
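

# Editor's sketch: the shape of what load_progress_data() returns. The values
# below are hypothetical; real rows come from sarm_progress.parquet.
def _example_progress_array() -> np.ndarray:
    """Sketch: (N, 2) array pairing frame_index with progress in [0, 1]."""
    return np.array([[0, 0.0], [10, 0.31], [20, 0.74], [29, 1.0]])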


def _precompute_pixel_coords(
    progress_data: np.ndarray,
    num_frames: int,
    frame_width: int,
    frame_height: int,
) -> np.ndarray:
    """Map progress samples to pixel coordinates for overlay drawing.

    Args:
        progress_data: (N, 2) array of (frame_index, progress).
        num_frames: Total number of video frames.
        frame_width: Video width in pixels.
        frame_height: Video height in pixels.

    Returns:
        (N, 2) array of (x, y) pixel coordinates.
    """
    frame_indices = progress_data[:, 0].astype(float)
    progress_values = np.clip(progress_data[:, 1].astype(float), 0.0, 1.0)

    y_top = int(frame_height * GRAPH_Y_TOP_FRAC)
    y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
    graph_height = y_bot - y_top

    x_coords = (frame_indices / (num_frames - 1) * (frame_width - 1)).astype(int)
    y_coords = (y_bot - progress_values * graph_height).astype(int)

    return np.stack([x_coords, y_coords], axis=1)
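

# Editor's sketch: a worked example of the mapping above for a hypothetical
# 100-frame episode at 640x480. The first sample lands on the left edge at the
# bottom of the graph band; the last on the right edge near the top.
def _example_pixel_mapping() -> None:
    coords = _precompute_pixel_coords(
        progress_data=np.array([[0, 0.0], [99, 1.0]]),
        num_frames=100,
        frame_width=640,
        frame_height=480,
    )
    assert coords[0].tolist() == [0, 475]  # progress 0.0 -> y_bot = int(480 * 0.99)
    assert coords[1].tolist() == [639, 4]  # progress 1.0 -> y_bot - graph_height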


def _progress_color(normalized_position: float) -> tuple[int, int, int]:
    """Interpolate BGR color from red to green based on position in [0, 1].

    Args:
        normalized_position: Value in [0, 1] indicating how far along the episode.

    Returns:
        BGR color tuple.
    """
    red = int(255 * (1.0 - normalized_position))
    green = int(255 * normalized_position)
    return (0, green, red)
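

# Editor's sketch: the gradient endpoints of _progress_color(). OpenCV works in
# BGR order, so the line fades from red at the start of an episode to green at the end.
def _example_progress_color_endpoints() -> None:
    assert _progress_color(0.0) == (0, 0, 255)  # red in BGR
    assert _progress_color(1.0) == (0, 255, 0)  # green in BGR
    assert _progress_color(0.5) == (0, 127, 127)  # midpoint: dark yellow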


def _prerender_fill_polygon(
    pixel_coords: np.ndarray,
    frame_width: int,
    frame_height: int,
) -> np.ndarray:
    """Pre-render the grey fill polygon under the progress curve as a BGRA image.

    Args:
        pixel_coords: (N, 2) array of (x, y) pixel coordinates.
        frame_width: Video width in pixels.
        frame_height: Video height in pixels.

    Returns:
        BGRA image array of shape (frame_height, frame_width, 4).
    """
    y_bot = int(frame_height * GRAPH_Y_BOT_FRAC)
    fill_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
    polygon = np.concatenate(
        [
            pixel_coords,
            [[pixel_coords[-1][0], y_bot], [pixel_coords[0][0], y_bot]],
        ],
        axis=0,
    ).astype(np.int32)
    cv2.fillPoly(fill_image, [polygon], color=(128, 128, 128, int(255 * FILL_ALPHA)))
    return fill_image


def _alpha_composite_region(base: np.ndarray, overlay_bgra: np.ndarray, x_limit: int) -> None:
    """Blend BGRA overlay onto BGR base in-place, up to x_limit columns.

    Args:
        base: BGR frame to draw on (modified in-place).
        overlay_bgra: BGRA overlay image.
        x_limit: Only blend columns [0, x_limit).
    """
    if x_limit <= 0:
        return
    region_base = base[:, :x_limit]
    region_overlay = overlay_bgra[:, :x_limit]
    alpha = region_overlay[:, :, 3:4].astype(np.float32) / 255.0
    region_base[:] = np.clip(
        region_overlay[:, :, :3].astype(np.float32) * alpha + region_base.astype(np.float32) * (1.0 - alpha),
        0,
        255,
    ).astype(np.uint8)
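

# Editor's sketch: the in-place blend above on a tiny synthetic frame. An opaque
# white BGRA overlay replaces the base pixel, but only left of x_limit.
def _example_alpha_blend() -> None:
    base = np.zeros((1, 2, 3), dtype=np.uint8)  # two black BGR pixels
    overlay = np.full((1, 2, 4), 255, dtype=np.uint8)  # opaque white BGRA
    _alpha_composite_region(base, overlay, x_limit=1)
    assert base[0, 0].tolist() == [255, 255, 255]  # column 0: blended
    assert base[0, 1].tolist() == [0, 0, 0]  # column 1: beyond x_limit, untouched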


def _draw_text_outlined(
    frame: np.ndarray,
    text: str,
    position: tuple[int, int],
    font_scale: float,
    thickness: int = 1,
) -> None:
    """Draw white text with a dark outline for readability on any background.

    Args:
        frame: BGR image to draw on (modified in-place).
        text: String to render.
        position: (x, y) bottom-left corner of the text.
        font_scale: OpenCV font scale.
        thickness: Text stroke thickness.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(frame, text, position, font, font_scale, (0, 0, 0), thickness + 2, cv2.LINE_AA)
    cv2.putText(frame, text, position, font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)


def composite_progress_video(
    video_path: Path,
    from_timestamp: float,
    to_timestamp: float,
    progress_data: np.ndarray,
    output_path: Path,
    fps: float,
    task_name: str = "",
) -> Path:
    """Read episode frames by seeking into the source video, draw progress overlay, write output.

    Uses cv2.CAP_PROP_POS_MSEC to seek directly into the source video,
    eliminating the need for an intermediate clip file.

    Args:
        video_path: Path to the full source video file.
        from_timestamp: Start timestamp of the episode in seconds.
        to_timestamp: End timestamp of the episode in seconds.
        progress_data: (N, 2) array of (frame_index, progress).
        output_path: Path to write the output MP4.
        fps: Frames per second for the output video.
        task_name: Optional task name to display at the top of the video.

    Returns:
        Path to the written output file (MP4).
    """
    capture = cv2.VideoCapture(str(video_path))
    try:
        capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000)

        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        duration_seconds = to_timestamp - from_timestamp
        num_frames = int(round(duration_seconds * fps))

        logging.info(
            " Video: %dx%d, %d frames @ %.1f fps (%.2fs)",
            frame_width,
            frame_height,
            num_frames,
            fps,
            duration_seconds,
        )

        pixel_coords = _precompute_pixel_coords(progress_data, num_frames, frame_width, frame_height)
        y_ref = int(frame_height * GRAPH_Y_TOP_FRAC)

        fill_image = _prerender_fill_polygon(pixel_coords, frame_width, frame_height)

        ref_line_image = np.zeros((frame_height, frame_width, 4), dtype=np.uint8)
        cv2.line(
            ref_line_image,
            (0, y_ref),
            (frame_width - 1, y_ref),
            (200, 200, 200, int(255 * REF_ALPHA)),
            1,
            cv2.LINE_AA,
        )

        frame_indices = progress_data[:, 0].astype(int)
        progress_values = progress_data[:, 1].astype(float)

        logging.info("[3/4] Compositing %d frames ...", num_frames)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))

        for frame_idx in range(num_frames):
            ret, frame = capture.read()
            if not ret:
                break

            drawn_count = int(np.searchsorted(frame_indices, frame_idx, side="right"))
            x_current = (
                int(pixel_coords[min(drawn_count, len(pixel_coords)) - 1][0]) + 1 if drawn_count > 0 else 0
            )

            _alpha_composite_region(frame, ref_line_image, frame_width)
            _alpha_composite_region(frame, fill_image, x_current)

            if drawn_count >= 2:
                time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
                line_color = _progress_color(time_position)
                points = pixel_coords[:drawn_count].reshape(-1, 1, 2).astype(np.int32)
                cv2.polylines(
                    frame,
                    [points],
                    isClosed=False,
                    color=(255, 255, 255),
                    thickness=SHADOW_THICKNESS,
                    lineType=cv2.LINE_AA,
                )
                cv2.polylines(
                    frame,
                    [points],
                    isClosed=False,
                    color=line_color,
                    thickness=LINE_THICKNESS,
                    lineType=cv2.LINE_AA,
                )

            if drawn_count > 0:
                score = float(progress_values[min(drawn_count, len(progress_values)) - 1])
                score_text = f"{score:.2f}"
                (text_width, _), _ = cv2.getTextSize(
                    score_text, cv2.FONT_HERSHEY_SIMPLEX, SCORE_FONT_SCALE, 2
                )
                score_x = frame_width - text_width - 12
                score_y = frame_height - 12
                time_position = (drawn_count - 1) / max(len(progress_values) - 1, 1)
                score_color = _progress_color(time_position)
                cv2.putText(
                    frame,
                    score_text,
                    (score_x, score_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    SCORE_FONT_SCALE,
                    (0, 0, 0),
                    4,
                    cv2.LINE_AA,
                )
                cv2.putText(
                    frame,
                    score_text,
                    (score_x, score_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    SCORE_FONT_SCALE,
                    score_color,
                    2,
                    cv2.LINE_AA,
                )

            if task_name:
                (text_width, _), _ = cv2.getTextSize(task_name, cv2.FONT_HERSHEY_SIMPLEX, TASK_FONT_SCALE, 1)
                task_x = max((frame_width - text_width) // 2, 4)
                _draw_text_outlined(frame, task_name, (task_x, 22), TASK_FONT_SCALE)

            writer.write(frame)
            if frame_idx % 100 == 0:
                logging.info(" Frame %d/%d ...", frame_idx, num_frames)

        writer.release()
    finally:
        capture.release()

    logging.info(" MP4 written: %s", output_path)
    return output_path
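

# Editor's sketch: the direct-seek pattern used above, in isolation. Setting
# CAP_PROP_POS_MSEC positions the decoder at the episode's start timestamp
# inside the large concatenated source file, which is what lets the script skip
# writing an intermediate per-episode clip.
def _example_direct_seek(video_path: Path, from_timestamp: float, num_frames: int) -> int:
    """Sketch: count how many frames can be read starting at from_timestamp."""
    capture = cv2.VideoCapture(str(video_path))
    try:
        capture.set(cv2.CAP_PROP_POS_MSEC, from_timestamp * 1000)
        frames_read = 0
        for _ in range(num_frames):
            ret, _frame = capture.read()
            if not ret:
                break
            frames_read += 1
        return frames_read
    finally:
        capture.release()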


def convert_mp4_to_gif(mp4_path: Path) -> Path:
    """Convert an MP4 to an optimized GIF using ffmpeg palette generation.

    Args:
        mp4_path: Path to the source MP4 file.

    Returns:
        Path to the generated GIF file.
    """
    capture = cv2.VideoCapture(str(mp4_path))
    frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    capture.release()

    gif_path = mp4_path.with_suffix(".gif")
    palette_path = mp4_path.parent / "_palette.png"

    logging.info("[4/4] Converting to GIF ...")
    result_palette = subprocess.run(  # nosec B607
        [
            "ffmpeg",
            "-y",
            "-i",
            str(mp4_path),
            "-vf",
            f"fps=10,scale={frame_width}:-1:flags=lanczos,palettegen=max_colors=128:stats_mode=diff",
            "-update",
            "1",
            str(palette_path),
        ],
        capture_output=True,
        text=True,
    )
    if result_palette.returncode != 0:
        logging.warning("palettegen failed:\n%s", result_palette.stderr[-500:])

    result_gif = subprocess.run(  # nosec B607
        [
            "ffmpeg",
            "-y",
            "-i",
            str(mp4_path),
            "-i",
            str(palette_path),
            "-filter_complex",
            f"fps=10,scale={frame_width}:-1:flags=lanczos[v];[v][1:v]paletteuse=dither=bayer:bayer_scale=3",
            str(gif_path),
        ],
        capture_output=True,
        text=True,
    )
    if result_gif.returncode != 0:
        logging.warning("GIF encode failed:\n%s", result_gif.stderr[-500:])

    palette_path.unlink(missing_ok=True)
    logging.info(" GIF written: %s", gif_path)
    return gif_path


def process_dataset(
    repo_id: str,
    episode: int,
    camera_key: str | None,
    output_dir: Path,
    create_gif: bool = False,
) -> Path | None:
    """Full pipeline: download, extract metadata, composite progress, write output.

    Args:
        repo_id: HuggingFace dataset repository ID.
        episode: Episode index.
        camera_key: Camera key to use, or None for auto-selection.
        output_dir: Directory to write output files.
        create_gif: If True, also generate a GIF from the MP4.

    Returns:
        Path to the final output file, or None on failure.
    """
    safe_name = repo_id.replace("/", "_")
    logging.info("Processing: %s | episode %d", repo_id, episode)

    local_path = download_episode_metadata(repo_id, episode)
    logging.info(" Local cache: %s", local_path)

    episode_meta = load_episode_meta(local_path, episode, camera_key)
    logging.info(" Episode meta: %s", episode_meta)

    video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])

    progress_data = load_progress_data(local_path, episode)
    if progress_data is None:
        logging.error("Could not load sarm_progress data. Skipping overlay.")
        return None

    logging.info(" Progress frames: %d", len(progress_data))

    output_path = output_dir / f"{safe_name}_ep{episode}_progress.mp4"
    final_path = composite_progress_video(
        video_path=video_path,
        from_timestamp=episode_meta["from_ts"],
        to_timestamp=episode_meta["to_ts"],
        progress_data=progress_data,
        output_path=output_path,
        fps=episode_meta["fps"],
        task_name=episode_meta.get("task_name", ""),
    )

    if create_gif:
        final_path = convert_mp4_to_gif(final_path)

    logging.info("Done: %s", final_path)
    return final_path
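

# Editor's sketch: process_dataset() can also be driven programmatically instead
# of through the CLI; the repo id and episode below mirror the module docstring's
# usage example.
def _example_programmatic_use() -> None:
    out_dir = Path("progress_videos")
    out_dir.mkdir(parents=True, exist_ok=True)  # main() does this for the CLI path
    process_dataset(
        repo_id="lerobot-data-collection/level2_final_quality3",
        episode=1100,
        camera_key=None,  # auto-select the first available camera
        output_dir=out_dir,
        create_gif=False,  # MP4 only; pass True to also emit a GIF
    )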


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID (e.g. 'lerobot-data-collection/level2_final_quality3').",
    )
    parser.add_argument(
        "--episode",
        type=int,
        required=True,
        help="Episode index to visualize.",
    )
    parser.add_argument(
        "--camera-key",
        type=str,
        default=None,
        help="Camera observation key (e.g. 'observation.images.base'). Auto-selects first camera if omitted.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("progress_videos"),
        help="Directory to write output files (default: ./progress_videos).",
    )
    parser.add_argument(
        "--gif",
        action="store_true",
        help="Also generate a GIF from the MP4 output.",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    result = process_dataset(
        repo_id=args.repo_id,
        episode=args.episode,
        camera_key=args.camera_key,
        output_dir=args.output_dir,
        create_gif=args.gif,
    )

    if result:
        logging.info("Output: %s", result)


if __name__ == "__main__":
    main()