fix(save images): fixing image saving in dataset tools

This commit is contained in:
CarolinePascal
2026-06-15 17:50:10 +02:00
parent addbf8d7e4
commit 9dd7aee176
3 changed files with 20 additions and 16 deletions
+15 -13
View File
@@ -54,6 +54,7 @@ from .compute_stats import (
compute_relative_action_stats,
)
from .dataset_metadata import LeRobotDatasetMetadata
from .image_writer import write_image
from .io_utils import (
get_parquet_file_size_in_mb,
load_episodes,
@@ -68,6 +69,8 @@ from .utils import (
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEPTH_FILE_PATTERN,
IMAGE_FILE_PATTERN,
VIDEO_DIR,
update_chunk_file_indices,
)
@@ -1157,15 +1160,15 @@ def _save_episode_images_for_video(
# Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image
def save_single_image(i_item_tuple):
i, item = i_item_tuple
img = item[img_key]
# Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
write_image(item[img_key], imgs_dir / frame_pattern.format(frame_index=i))
return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor:
@@ -1197,16 +1200,14 @@ def _save_batch_episodes_images(
hf_dataset = dataset.hf_dataset.with_format(None)
imgs_dataset = hf_dataset.select_columns(img_key)
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
# Define function to save a single image with global frame index
# Defined once outside the loop to avoid repeated closure creation
def save_single_image(i_item_tuple, base_frame_idx, img_key_param):
i, item = i_item_tuple
img = item[img_key_param]
# Use global frame index for naming
if img_key_param in dataset.meta.depth_keys:
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.tiff"), compression="raw")
else:
img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100)
write_image(item[img_key_param], imgs_dir / frame_pattern.format(frame_index=base_frame_idx + i))
return i
episode_durations = []
@@ -1336,10 +1337,11 @@ def _estimate_frame_size_via_calibration(
hf_dataset = dataset.hf_dataset.with_format(None)
sample_indices = range(from_idx, from_idx + num_frames)
# Save calibration frames
# Save calibration frames using the suffix/format the encoder expects.
is_depth = img_key in dataset.meta.depth_keys
frame_pattern = DEPTH_FILE_PATTERN if is_depth else IMAGE_FILE_PATTERN
for i, idx in enumerate(sample_indices):
img = hf_dataset[idx][img_key]
img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100)
write_image(hf_dataset[idx][img_key], calibration_dir / frame_pattern.format(frame_index=i))
# Encode calibration video
calibration_video_path = calibration_dir / "calibration.mp4"
+4 -2
View File
@@ -87,12 +87,14 @@ DATA_DIR = "data"
VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
IMAGE_FILE_PATTERN = "frame-{frame_index:06d}.png"
DEPTH_FILE_PATTERN = "frame-{frame_index:06d}.tiff"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/" + IMAGE_FILE_PATTERN
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/" + DEPTH_FILE_PATTERN
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
+1 -1
View File
@@ -481,7 +481,7 @@ def encode_video_frames(
)
if len(input_list) == 0:
raise FileNotFoundError(f"No images found in {imgs_dir}.")
raise FileNotFoundError(f"No images with suffix {suffix} found in {imgs_dir}.")
with Image.open(input_list[0]) as dummy_image:
width, height = dummy_image.size