diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 3fb0c6c4e..7f481b9ca 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -36,6 +36,16 @@ class DatasetConfig: video_backend: str = field(default_factory=get_safe_default_codec) streaming: bool = False + def __post_init__(self) -> None: + if self.episodes is not None: + if any(ep < 0 for ep in self.episodes): + raise ValueError( + f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}" + ) + if len(self.episodes) != len(set(self.episodes)): + duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1}) + raise ValueError(f"Episode indices contain duplicates: {duplicates}") + @dataclass class WandBConfig: diff --git a/src/lerobot/datasets/image_writer.py b/src/lerobot/datasets/image_writer.py index 23bc2efb8..9f40394de 100644 --- a/src/lerobot/datasets/image_writer.py +++ b/src/lerobot/datasets/image_writer.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import multiprocessing import queue import threading @@ -22,6 +23,8 @@ import numpy as np import PIL.Image import torch +logger = logging.getLogger(__name__) + def safe_stop_image_writer(func): def wrapper(*args, **kwargs): @@ -31,7 +34,7 @@ def safe_stop_image_writer(func): dataset = kwargs.get("dataset") image_writer = getattr(dataset, "image_writer", None) if dataset else None if image_writer is not None: - print("Waiting for image writer to terminate...") + logger.warning("Waiting for image writer to terminate...") image_writer.stop() raise e @@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level PIL.Image.Image object. Side Effects: - Prints an error message to the console if the image writing process - fails for any reason. + Logs an error message if the image writing process fails for any reason. """ try: if isinstance(image, np.ndarray): @@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level raise TypeError(f"Unsupported image type: {type(image)}") img.save(fpath, compress_level=compress_level) except Exception as e: - print(f"Error writing image {fpath}: {e}") + logger.error("Error writing image %s: %s", fpath, e) def worker_thread_loop(queue: queue.Queue): diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 11c10f493..5d1b5d042 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -80,6 +80,8 @@ from lerobot.datasets.video_utils import ( ) from lerobot.utils.constants import HF_LEROBOT_HOME +logger = logging.getLogger(__name__) + CODEBASE_VERSION = "v3.0" @@ -535,7 +537,10 @@ class LeRobotDatasetMetadata: video_files_size_in_mb, ) if len(obj.video_keys) > 0 and not use_videos: - raise ValueError() + raise ValueError( + f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. " + "Either remove video features from the features dict, or set 'use_videos=True'." + ) write_json(obj.info, obj.root / INFO_PATH) obj.revision = None obj.writer = None @@ -1326,7 +1331,7 @@ class LeRobotDataset(torch.utils.data.Dataset): temp_path = future.result() results[video_key] = temp_path except Exception as exc: - logging.error(f"Video encoding failed for {video_key}: {exc}") + logger.error(f"Video encoding failed for {video_key}: {exc}") raise exc for video_key in self.meta.video_keys: @@ -1365,7 +1370,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if end_episode is None: end_episode = self.num_episodes - logging.info( + logger.info( f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}" ) @@ -1375,7 +1380,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_df = pd.read_parquet(episode_df_path) for ep_idx in range(start_episode, end_episode): - logging.info(f"Encoding videos for episode {ep_idx}") + logger.info(f"Encoding videos for episode {ep_idx}") if ( self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx @@ -1605,7 +1610,7 @@ class LeRobotDataset(torch.utils.data.Dataset): def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None: if isinstance(self.image_writer, AsyncImageWriter): - logging.warning( + logger.warning( "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset." ) @@ -1771,7 +1776,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True): extra_keys = set(ds.features).difference(intersection_features) if extra_keys: - logging.warning( + logger.warning( f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the " "other datasets." ) diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index f824eb9bc..fe8cabbeb 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -44,11 +44,11 @@ def create_initial_features( return features -# Helper to filter state/action keys based on regex patterns. -def should_keep(key: str, patterns: tuple[str]) -> bool: +# Helper to filter state/action keys based on compiled regex patterns. +def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool: if patterns is None: return True - return any(re.search(pat, key) for pat in patterns) + return any(pat.search(key) for pat in patterns) def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str: @@ -89,6 +89,8 @@ def aggregate_pipeline_dataset_features( Returns: A dictionary of features formatted for a Hugging Face LeRobot Dataset. """ + compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None + all_features = pipeline.transform_features(initial_features) # Intermediate storage for categorized and filtered features. @@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features( # 2. Apply filtering rules. if is_image and not use_videos: continue - if not is_image and not should_keep(key, patterns): + if not is_image and not should_keep(key, compiled_patterns): continue # 3. Add the feature to the appropriate group with a clean name. diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index d0bb20c27..2bf7ab922 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -13,10 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections.abc import Iterator import torch +logger = logging.getLogger(__name__) + class EpisodeAwareSampler: def __init__( @@ -39,13 +42,35 @@ class EpisodeAwareSampler: drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. """ + if drop_n_first_frames < 0: + raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") + if drop_n_last_frames < 0: + raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") + indices = [] for episode_idx, (start_index, end_index) in enumerate( zip(dataset_from_indices, dataset_to_indices, strict=True) ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: + ep_length = end_index - start_index + if drop_n_first_frames + drop_n_last_frames >= ep_length: + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + ep_length, + drop_n_first_frames, + drop_n_last_frames, + ) + continue indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + if not indices: + raise ValueError( + "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " + "All episodes were either filtered out or had too few frames." + ) + self.indices = indices self.shuffle = shuffle diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 8c8494b87..e465b79b4 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -37,6 +37,8 @@ import torchvision from datasets.features.features import register_feature from PIL import Image +logger = logging.getLogger(__name__) + # List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. # Determines the order of preference for auto-selection when vcodec="auto" is used. HW_ENCODERS = [ @@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]: av.codec.Codec(codec_name, "w") available.append(codec_name) except Exception: # nosec B110 - pass # nosec B110 + logger.debug("HW encoder '%s' not available", codec_name) # nosec B110 return available @@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str: if vcodec not in VALID_VIDEO_CODECS: raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") if vcodec != "auto": - logging.info(f"Using video codec: {vcodec}") + logger.info(f"Using video codec: {vcodec}") return vcodec available = detect_available_hw_encoders() for encoder in HW_ENCODERS: if encoder in available: - logging.info(f"Auto-selected video codec: {encoder}") + logger.info(f"Auto-selected video codec: {encoder}") return encoder - logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") + logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") return "libsvtav1" @@ -118,7 +120,7 @@ def get_safe_default_codec(): if importlib.util.find_spec("torchcodec"): return "torchcodec" else: - logging.warning( + logger.warning( "'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder" ) return "pyav" @@ -208,7 +210,7 @@ def decode_video_frames_torchvision( for frame in reader: current_ts = frame["pts"] if log_loaded_timestamps: - logging.info(f"frame loaded at timestamp={current_ts:.4f}") + logger.info(f"frame loaded at timestamp={current_ts:.4f}") loaded_frames.append(frame["data"]) loaded_ts.append(current_ts) if current_ts >= last_ts: @@ -244,7 +246,7 @@ def decode_video_frames_torchvision( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to the pytorch format which is float32 in [0,1] range (and channel first) closest_frames = closest_frames.type(torch.float32) / 255 @@ -348,7 +350,7 @@ def decode_video_frames_torchcodec( loaded_frames.append(frame) loaded_ts.append(pts.item()) if log_loaded_timestamps: - logging.info(f"Frame loaded at timestamp={pts:.4f}") + logger.info(f"Frame loaded at timestamp={pts:.4f}") query_ts = torch.tensor(timestamps) loaded_ts = torch.tensor(loaded_ts) @@ -374,7 +376,7 @@ def decode_video_frames_torchcodec( closest_ts = loaded_ts[argmin_] if log_loaded_timestamps: - logging.info(f"{closest_ts=}") + logger.info(f"{closest_ts=}") # convert to float32 in [0,1] range closest_frames = (closest_frames / 255.0).type(torch.float32) @@ -408,14 +410,14 @@ def encode_video_frames( imgs_dir = Path(imgs_dir) if video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {video_path}. Skipping encoding.") + logger.warning(f"Video file already exists: {video_path}. Skipping encoding.") return video_path.parent.mkdir(parents=True, exist_ok=True) # Encoders/pixel formats incompatibility check if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p": - logging.warning( + logger.warning( f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'" ) pix_fmt = "yuv420p" @@ -508,7 +510,7 @@ def concatenate_video_files( output_video_path = Path(output_video_path) if output_video_path.exists() and not overwrite: - logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") + logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.") return output_video_path.parent.mkdir(parents=True, exist_ok=True) @@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread): self.result_queue.put(("ok", None)) except Exception as e: - logging.error(f"Encoder thread error: {e}") + logger.error(f"Encoder thread error: {e}") if container is not None: with contextlib.suppress(Exception): container.close() @@ -819,7 +821,7 @@ class StreamingVideoEncoder: count = self._dropped_frames[video_key] # Log periodically to avoid spam (1st, then every 10th) if count == 1 or count % 10 == 0: - logging.warning( + logger.warning( f"Encoder queue full for {video_key}, dropped {count} frame(s). " f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize." ) @@ -841,7 +843,7 @@ class StreamingVideoEncoder: # Report dropped frames for video_key, count in self._dropped_frames.items(): if count > 0: - logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") + logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") # Send sentinel to all queues for video_key in self._frame_queues: @@ -851,7 +853,7 @@ class StreamingVideoEncoder: for video_key in self._threads: self._threads[video_key].join(timeout=120) if self._threads[video_key].is_alive(): - logging.error(f"Encoder thread for {video_key} did not finish in time") + logger.error(f"Encoder thread for {video_key} did not finish in time") self._stop_events[video_key].set() self._threads[video_key].join(timeout=5) results[video_key] = (self._video_paths[video_key], None) @@ -863,7 +865,7 @@ class StreamingVideoEncoder: raise RuntimeError(f"Encoder thread for {video_key} failed: {data}") results[video_key] = (self._video_paths[video_key], data) except queue.Empty: - logging.error(f"No result from encoder thread for {video_key}") + logger.error(f"No result from encoder thread for {video_key}") results[video_key] = (self._video_paths[video_key], None) self._cleanup() @@ -1071,13 +1073,13 @@ class VideoEncodingManager: elif self.dataset.episodes_since_last_encoding > 0: # Handle any remaining episodes that haven't been batch encoded if exc_type is not None: - logging.info("Exception occurred. Encoding remaining episodes before exit...") + logger.info("Exception occurred. Encoding remaining episodes before exit...") else: - logging.info("Recording stopped. Encoding remaining episodes...") + logger.info("Recording stopped. Encoding remaining episodes...") start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding end_ep = self.dataset.num_episodes - logging.info( + logger.info( f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " f"from episode {start_ep} to {end_ep - 1}" ) @@ -1094,7 +1096,7 @@ class VideoEncodingManager: episode_index=interrupted_episode_index, image_key=key, frame_index=0 ).parent if img_dir.exists(): - logging.debug( + logger.debug( f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" ) shutil.rmtree(img_dir) @@ -1105,8 +1107,8 @@ class VideoEncodingManager: png_files = list(img_dir.rglob("*.png")) if len(png_files) == 0: shutil.rmtree(img_dir) - logging.debug("Cleaned up empty images directory") + logger.debug("Cleaned up empty images directory") else: - logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") return False # Don't suppress the original exception diff --git a/tests/configs/test_default.py b/tests/configs/test_default.py new file mode 100644 index 000000000..238b8bacd --- /dev/null +++ b/tests/configs/test_default.py @@ -0,0 +1,38 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from lerobot.configs.default import DatasetConfig + + +def test_dataset_config_valid(): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2]) + + +def test_dataset_config_negative_episodes(): + with pytest.raises(ValueError, match="non-negative"): + DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2]) + + +def test_dataset_config_duplicate_episodes(): + with pytest.raises(ValueError, match="duplicates"): + DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2]) + + +def test_dataset_config_none_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=None) + + +def test_dataset_config_empty_episodes_ok(): + DatasetConfig(repo_id="user/repo", episodes=[]) diff --git a/tests/datasets/test_image_writer.py b/tests/datasets/test_image_writer.py index 99c8b24fc..e02755171 100644 --- a/tests/datasets/test_image_writer.py +++ b/tests/datasets/test_image_writer.py @@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory): def test_write_image_exception(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: write_image(image_array, fpath) - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() @@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path): image_array = "invalid data" fpath = tmp_path / DUMMY_IMAGE fpath.parent.mkdir(parents=True, exist_ok=True) - with patch("builtins.print") as mock_print: + with patch("lerobot.datasets.image_writer.logger") as mock_logger: writer.save_image(image_array, fpath) writer.wait_until_done() - mock_print.assert_called() + mock_logger.error.assert_called() assert not fpath.exists() finally: writer.stop() diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index e5b35e426..a5d463349 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -13,6 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import pytest import torch from datasets import Dataset @@ -106,3 +109,28 @@ def test_shuffle(): assert sampler.indices == [0, 1, 2, 3, 4, 5] assert len(sampler) == 6 assert set(sampler) == {0, 1, 2, 3, 4, 5} + + +def test_negative_drop_first_frames_raises(): + with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): + EpisodeAwareSampler([0], [10], drop_n_first_frames=-1) + + +def test_negative_drop_last_frames_raises(): + with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"): + EpisodeAwareSampler([0], [10], drop_n_last_frames=-1) + + +def test_all_episodes_dropped_raises(): + # All episodes have 1 frame, drop_n_first_frames=1 removes all + with pytest.raises(ValueError, match="No valid frames remain"): + EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1) + + +def test_partial_episode_drop_warns(caplog): + # Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept) + with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"): + sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1) + # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 + assert sampler.indices == [2, 3, 4, 5] + assert "Episode 0" in caplog.text