chore(dataset): basic house-keeping (#3170)

This commit is contained in:
Steven Palma
2026-03-15 22:12:09 -07:00
committed by GitHub
parent 7c2ec31793
commit 9d3b62aa61
9 changed files with 153 additions and 41 deletions
+10
View File
@@ -36,6 +36,16 @@ class DatasetConfig:
video_backend: str = field(default_factory=get_safe_default_codec)
streaming: bool = False
def __post_init__(self) -> None:
    """Validate the optional ``episodes`` selection after dataclass init.

    ``episodes`` may be ``None`` (no explicit selection) or an empty list; both
    are accepted without further checks. Otherwise every index must be
    non-negative and must appear at most once.

    Raises:
        ValueError: If any episode index is negative, or if the same index is
            listed more than once. The message lists the offending values.
    """
    if self.episodes is None:
        return

    # Negative indices are reported first, in their original order.
    negatives = [ep for ep in self.episodes if ep < 0]
    if negatives:
        raise ValueError(f"Episode indices must be non-negative, got: {negatives}")

    # Single pass with a seen-set instead of calling list.count() per element,
    # which would be quadratic on long episode lists.
    seen = set()
    repeated = set()
    for ep in self.episodes:
        if ep in seen:
            repeated.add(ep)
        else:
            seen.add(ep)
    if repeated:
        duplicates = sorted(repeated)
        raise ValueError(f"Episode indices contain duplicates: {duplicates}")
@dataclass
class WandBConfig:
+6 -4
View File
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import multiprocessing
import queue
import threading
@@ -22,6 +23,8 @@ import numpy as np
import PIL.Image
import torch
logger = logging.getLogger(__name__)
def safe_stop_image_writer(func):
def wrapper(*args, **kwargs):
@@ -31,7 +34,7 @@ def safe_stop_image_writer(func):
dataset = kwargs.get("dataset")
image_writer = getattr(dataset, "image_writer", None) if dataset else None
if image_writer is not None:
print("Waiting for image writer to terminate...")
logger.warning("Waiting for image writer to terminate...")
image_writer.stop()
raise e
@@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
PIL.Image.Image object.
Side Effects:
Prints an error message to the console if the image writing process
fails for any reason.
Logs an error message if the image writing process fails for any reason.
"""
try:
if isinstance(image, np.ndarray):
@@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath, compress_level=compress_level)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
logger.error("Error writing image %s: %s", fpath, e)
def worker_thread_loop(queue: queue.Queue):
+11 -6
View File
@@ -80,6 +80,8 @@ from lerobot.datasets.video_utils import (
)
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
CODEBASE_VERSION = "v3.0"
@@ -535,7 +537,10 @@ class LeRobotDatasetMetadata:
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
raise ValueError(
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
"Either remove video features from the features dict, or set 'use_videos=True'."
)
write_json(obj.info, obj.root / INFO_PATH)
obj.revision = None
obj.writer = None
@@ -1326,7 +1331,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logging.error(f"Video encoding failed for {video_key}: {exc}")
logger.error(f"Video encoding failed for {video_key}: {exc}")
raise exc
for video_key in self.meta.video_keys:
@@ -1365,7 +1370,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if end_episode is None:
end_episode = self.num_episodes
logging.info(
logger.info(
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
)
@@ -1375,7 +1380,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_df = pd.read_parquet(episode_df_path)
for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding videos for episode {ep_idx}")
logger.info(f"Encoding videos for episode {ep_idx}")
if (
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
@@ -1605,7 +1610,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
logging.warning(
logger.warning(
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
@@ -1771,7 +1776,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
if extra_keys:
logging.warning(
logger.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
+6 -4
View File
@@ -44,11 +44,11 @@ def create_initial_features(
return features
# Helper to filter state/action keys based on regex patterns.
def should_keep(key: str, patterns: tuple[str]) -> bool:
# Helper to filter state/action keys based on compiled regex patterns.
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
if patterns is None:
return True
return any(re.search(pat, key) for pat in patterns)
return any(pat.search(key) for pat in patterns)
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
@@ -89,6 +89,8 @@ def aggregate_pipeline_dataset_features(
Returns:
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
"""
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
all_features = pipeline.transform_features(initial_features)
# Intermediate storage for categorized and filtered features.
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
# 2. Apply filtering rules.
if is_image and not use_videos:
continue
if not is_image and not should_keep(key, patterns):
if not is_image and not should_keep(key, compiled_patterns):
continue
# 3. Add the feature to the appropriate group with a clean name.
+25
View File
@@ -13,10 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Iterator
import torch
logger = logging.getLogger(__name__)
class EpisodeAwareSampler:
def __init__(
@@ -39,13 +42,35 @@ class EpisodeAwareSampler:
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
if drop_n_last_frames < 0:
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(dataset_from_indices, dataset_to_indices, strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
ep_length = end_index - start_index
if drop_n_first_frames + drop_n_last_frames >= ep_length:
logger.warning(
"Episode %d has %d frames but drop_n_first_frames=%d and "
"drop_n_last_frames=%d removes all frames. Skipping.",
episode_idx,
ep_length,
drop_n_first_frames,
drop_n_last_frames,
)
continue
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
if not indices:
raise ValueError(
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
"All episodes were either filtered out or had too few frames."
)
self.indices = indices
self.shuffle = shuffle
+25 -23
View File
@@ -37,6 +37,8 @@ import torchvision
from datasets.features.features import register_feature
from PIL import Image
logger = logging.getLogger(__name__)
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
# Determines the order of preference for auto-selection when vcodec="auto" is used.
HW_ENCODERS = [
@@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]:
av.codec.Codec(codec_name, "w")
available.append(codec_name)
except Exception: # nosec B110
pass # nosec B110
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
return available
@@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str:
if vcodec not in VALID_VIDEO_CODECS:
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
if vcodec != "auto":
logging.info(f"Using video codec: {vcodec}")
logger.info(f"Using video codec: {vcodec}")
return vcodec
available = detect_available_hw_encoders()
for encoder in HW_ENCODERS:
if encoder in available:
logging.info(f"Auto-selected video codec: {encoder}")
logger.info(f"Auto-selected video codec: {encoder}")
return encoder
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
return "libsvtav1"
@@ -118,7 +120,7 @@ def get_safe_default_codec():
if importlib.util.find_spec("torchcodec"):
return "torchcodec"
else:
logging.warning(
logger.warning(
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
)
return "pyav"
@@ -208,7 +210,7 @@ def decode_video_frames_torchvision(
for frame in reader:
current_ts = frame["pts"]
if log_loaded_timestamps:
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
@@ -244,7 +246,7 @@ def decode_video_frames_torchvision(
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
logger.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
@@ -348,7 +350,7 @@ def decode_video_frames_torchcodec(
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
logging.info(f"Frame loaded at timestamp={pts:.4f}")
logger.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
@@ -374,7 +376,7 @@ def decode_video_frames_torchcodec(
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
logger.info(f"{closest_ts=}")
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
@@ -408,14 +410,14 @@ def encode_video_frames(
imgs_dir = Path(imgs_dir)
if video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
return
video_path.parent.mkdir(parents=True, exist_ok=True)
# Encoders/pixel formats incompatibility check
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
logging.warning(
logger.warning(
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
)
pix_fmt = "yuv420p"
@@ -508,7 +510,7 @@ def concatenate_video_files(
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
return
output_video_path.parent.mkdir(parents=True, exist_ok=True)
@@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread):
self.result_queue.put(("ok", None))
except Exception as e:
logging.error(f"Encoder thread error: {e}")
logger.error(f"Encoder thread error: {e}")
if container is not None:
with contextlib.suppress(Exception):
container.close()
@@ -819,7 +821,7 @@ class StreamingVideoEncoder:
count = self._dropped_frames[video_key]
# Log periodically to avoid spam (1st, then every 10th)
if count == 1 or count % 10 == 0:
logging.warning(
logger.warning(
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
)
@@ -841,7 +843,7 @@ class StreamingVideoEncoder:
# Report dropped frames
for video_key, count in self._dropped_frames.items():
if count > 0:
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
# Send sentinel to all queues
for video_key in self._frame_queues:
@@ -851,7 +853,7 @@ class StreamingVideoEncoder:
for video_key in self._threads:
self._threads[video_key].join(timeout=120)
if self._threads[video_key].is_alive():
logging.error(f"Encoder thread for {video_key} did not finish in time")
logger.error(f"Encoder thread for {video_key} did not finish in time")
self._stop_events[video_key].set()
self._threads[video_key].join(timeout=5)
results[video_key] = (self._video_paths[video_key], None)
@@ -863,7 +865,7 @@ class StreamingVideoEncoder:
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
results[video_key] = (self._video_paths[video_key], data)
except queue.Empty:
logging.error(f"No result from encoder thread for {video_key}")
logger.error(f"No result from encoder thread for {video_key}")
results[video_key] = (self._video_paths[video_key], None)
self._cleanup()
@@ -1071,13 +1073,13 @@ class VideoEncodingManager:
elif self.dataset.episodes_since_last_encoding > 0:
# Handle any remaining episodes that haven't been batch encoded
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
logger.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logging.info("Recording stopped. Encoding remaining episodes...")
logger.info("Recording stopped. Encoding remaining episodes...")
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logging.info(
logger.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
@@ -1094,7 +1096,7 @@ class VideoEncodingManager:
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logging.debug(
logger.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
@@ -1105,8 +1107,8 @@ class VideoEncodingManager:
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
logger.debug("Cleaned up empty images directory")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception
+38
View File
@@ -0,0 +1,38 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from lerobot.configs.default import DatasetConfig
def test_dataset_config_valid():
    """Unique, non-negative episode indices are accepted."""
    selected = [0, 1, 2]
    DatasetConfig(episodes=selected, repo_id="user/repo")


def test_dataset_config_negative_episodes():
    """A negative episode index triggers the 'non-negative' ValueError."""
    selected = [0, -1, 2]
    with pytest.raises(ValueError, match="non-negative"):
        DatasetConfig(episodes=selected, repo_id="user/repo")


def test_dataset_config_duplicate_episodes():
    """A repeated episode index triggers the 'duplicates' ValueError."""
    selected = [0, 1, 1, 2]
    with pytest.raises(ValueError, match="duplicates"):
        DatasetConfig(episodes=selected, repo_id="user/repo")


def test_dataset_config_none_episodes_ok():
    """``episodes=None`` means no explicit selection and must not raise."""
    DatasetConfig(episodes=None, repo_id="user/repo")


def test_dataset_config_empty_episodes_ok():
    """An empty episode list must not raise."""
    selected = []
    DatasetConfig(episodes=selected, repo_id="user/repo")
+4 -4
View File
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
write_image(image_array, fpath)
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
mock_logger.error.assert_called()
assert not fpath.exists()
finally:
writer.stop()
+28
View File
@@ -13,6 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pytest
import torch
from datasets import Dataset
@@ -106,3 +109,28 @@ def test_shuffle():
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
    """A negative ``drop_n_first_frames`` is rejected when the sampler is built."""
    starts, ends = [0], [10]
    with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
        EpisodeAwareSampler(starts, ends, drop_n_first_frames=-1)


def test_negative_drop_last_frames_raises():
    """A negative ``drop_n_last_frames`` is rejected when the sampler is built."""
    starts, ends = [0], [10]
    with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
        EpisodeAwareSampler(starts, ends, drop_n_last_frames=-1)


def test_all_episodes_dropped_raises():
    """Three one-frame episodes with one leading frame dropped leave nothing to sample."""
    starts = [0, 1, 2]
    ends = [1, 2, 3]
    with pytest.raises(ValueError, match="No valid frames remain"):
        EpisodeAwareSampler(starts, ends, drop_n_first_frames=1)


def test_partial_episode_drop_warns(caplog):
    """An episode emptied by the drop settings is skipped with a warning; the rest are kept."""
    # Episode 0 spans [0, 1) -> 1 frame, fully dropped.
    # Episode 1 spans [1, 6) -> 5 frames, the first one is dropped.
    starts = [0, 1]
    ends = [1, 6]
    with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
        sampler = EpisodeAwareSampler(starts, ends, drop_n_first_frames=1)

    expected_indices = [2, 3, 4, 5]
    assert sampler.indices == expected_indices
    assert "Episode 0" in caplog.text