mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
chore(dataset): basic house-keeping (#3170)
This commit is contained in:
@@ -36,6 +36,16 @@ class DatasetConfig:
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
streaming: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.episodes is not None:
|
||||
if any(ep < 0 for ep in self.episodes):
|
||||
raise ValueError(
|
||||
f"Episode indices must be non-negative, got: {[ep for ep in self.episodes if ep < 0]}"
|
||||
)
|
||||
if len(self.episodes) != len(set(self.episodes)):
|
||||
duplicates = sorted({ep for ep in self.episodes if self.episodes.count(ep) > 1})
|
||||
raise ValueError(f"Episode indices contain duplicates: {duplicates}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WandBConfig:
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import multiprocessing
|
||||
import queue
|
||||
import threading
|
||||
@@ -22,6 +23,8 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
@@ -31,7 +34,7 @@ def safe_stop_image_writer(func):
|
||||
dataset = kwargs.get("dataset")
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
logger.warning("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
@@ -89,8 +92,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
PIL.Image.Image object.
|
||||
|
||||
Side Effects:
|
||||
Prints an error message to the console if the image writing process
|
||||
fails for any reason.
|
||||
Logs an error message if the image writing process fails for any reason.
|
||||
"""
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
@@ -101,7 +103,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
|
||||
@@ -80,6 +80,8 @@ from lerobot.datasets.video_utils import (
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
|
||||
@@ -535,7 +537,10 @@ class LeRobotDatasetMetadata:
|
||||
video_files_size_in_mb,
|
||||
)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
raise ValueError(
|
||||
f"Features contain video keys {obj.video_keys}, but 'use_videos' is set to False. "
|
||||
"Either remove video features from the features dict, or set 'use_videos=True'."
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
obj.revision = None
|
||||
obj.writer = None
|
||||
@@ -1326,7 +1331,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
logger.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
@@ -1365,7 +1370,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if end_episode is None:
|
||||
end_episode = self.num_episodes
|
||||
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_episode} to {end_episode - 1}"
|
||||
)
|
||||
|
||||
@@ -1375,7 +1380,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_df = pd.read_parquet(episode_df_path)
|
||||
|
||||
for ep_idx in range(start_episode, end_episode):
|
||||
logging.info(f"Encoding videos for episode {ep_idx}")
|
||||
logger.info(f"Encoding videos for episode {ep_idx}")
|
||||
|
||||
if (
|
||||
self.meta.episodes[ep_idx]["data/chunk_index"] != chunk_idx
|
||||
@@ -1605,7 +1610,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
|
||||
if isinstance(self.image_writer, AsyncImageWriter):
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
|
||||
)
|
||||
|
||||
@@ -1771,7 +1776,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
|
||||
@@ -44,11 +44,11 @@ def create_initial_features(
|
||||
return features
|
||||
|
||||
|
||||
# Helper to filter state/action keys based on regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[str]) -> bool:
|
||||
# Helper to filter state/action keys based on compiled regex patterns.
|
||||
def should_keep(key: str, patterns: tuple[re.Pattern] | None) -> bool:
|
||||
if patterns is None:
|
||||
return True
|
||||
return any(re.search(pat, key) for pat in patterns)
|
||||
return any(pat.search(key) for pat in patterns)
|
||||
|
||||
|
||||
def strip_prefix(key: str, prefixes_to_strip: tuple[str]) -> str:
|
||||
@@ -89,6 +89,8 @@ def aggregate_pipeline_dataset_features(
|
||||
Returns:
|
||||
A dictionary of features formatted for a Hugging Face LeRobot Dataset.
|
||||
"""
|
||||
compiled_patterns = tuple(re.compile(p) for p in patterns) if patterns is not None else None
|
||||
|
||||
all_features = pipeline.transform_features(initial_features)
|
||||
|
||||
# Intermediate storage for categorized and filtered features.
|
||||
@@ -120,7 +122,7 @@ def aggregate_pipeline_dataset_features(
|
||||
# 2. Apply filtering rules.
|
||||
if is_image and not use_videos:
|
||||
continue
|
||||
if not is_image and not should_keep(key, patterns):
|
||||
if not is_image and not should_keep(key, compiled_patterns):
|
||||
continue
|
||||
|
||||
# 3. Add the feature to the appropriate group with a clean name.
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
@@ -39,13 +42,35 @@ class EpisodeAwareSampler:
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
if drop_n_last_frames < 0:
|
||||
raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}")
|
||||
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(dataset_from_indices, dataset_to_indices, strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
ep_length = end_index - start_index
|
||||
if drop_n_first_frames + drop_n_last_frames >= ep_length:
|
||||
logger.warning(
|
||||
"Episode %d has %d frames but drop_n_first_frames=%d and "
|
||||
"drop_n_last_frames=%d removes all frames. Skipping.",
|
||||
episode_idx,
|
||||
ep_length,
|
||||
drop_n_first_frames,
|
||||
drop_n_last_frames,
|
||||
)
|
||||
continue
|
||||
indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames))
|
||||
|
||||
if not indices:
|
||||
raise ValueError(
|
||||
"No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. "
|
||||
"All episodes were either filtered out or had too few frames."
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
@@ -94,7 +96,7 @@ def detect_available_hw_encoders() -> list[str]:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
@@ -103,14 +105,14 @@ def resolve_vcodec(vcodec: str) -> str:
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
@@ -118,7 +120,7 @@ def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
@@ -208,7 +210,7 @@ def decode_video_frames_torchvision(
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
logger.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
@@ -244,7 +246,7 @@ def decode_video_frames_torchvision(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
@@ -348,7 +350,7 @@ def decode_video_frames_torchcodec(
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
logger.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
@@ -374,7 +376,7 @@ def decode_video_frames_torchcodec(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logger.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
@@ -408,14 +410,14 @@ def encode_video_frames(
|
||||
imgs_dir = Path(imgs_dir)
|
||||
|
||||
if video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
logger.warning(f"Video file already exists: {video_path}. Skipping encoding.")
|
||||
return
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
@@ -508,7 +510,7 @@ def concatenate_video_files(
|
||||
output_video_path = Path(output_video_path)
|
||||
|
||||
if output_video_path.exists() and not overwrite:
|
||||
logging.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
logger.warning(f"Video file already exists: {output_video_path}. Skipping concatenation.")
|
||||
return
|
||||
|
||||
output_video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -693,7 +695,7 @@ class _CameraEncoderThread(threading.Thread):
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
logger.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
@@ -819,7 +821,7 @@ class StreamingVideoEncoder:
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
logger.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
@@ -841,7 +843,7 @@ class StreamingVideoEncoder:
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
logger.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
@@ -851,7 +853,7 @@ class StreamingVideoEncoder:
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
logger.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
@@ -863,7 +865,7 @@ class StreamingVideoEncoder:
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
logger.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
@@ -1071,13 +1073,13 @@ class VideoEncodingManager:
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
logger.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
logging.info("Recording stopped. Encoding remaining episodes...")
|
||||
logger.info("Recording stopped. Encoding remaining episodes...")
|
||||
|
||||
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
|
||||
end_ep = self.dataset.num_episodes
|
||||
logging.info(
|
||||
logger.info(
|
||||
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
|
||||
f"from episode {start_ep} to {end_ep - 1}"
|
||||
)
|
||||
@@ -1094,7 +1096,7 @@ class VideoEncodingManager:
|
||||
episode_index=interrupted_episode_index, image_key=key, frame_index=0
|
||||
).parent
|
||||
if img_dir.exists():
|
||||
logging.debug(
|
||||
logger.debug(
|
||||
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
@@ -1105,8 +1107,8 @@ class VideoEncodingManager:
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
logger.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
logger.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
|
||||
|
||||
def test_dataset_config_valid():
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_negative_episodes():
|
||||
with pytest.raises(ValueError, match="non-negative"):
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, -1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_duplicate_episodes():
|
||||
with pytest.raises(ValueError, match="duplicates"):
|
||||
DatasetConfig(repo_id="user/repo", episodes=[0, 1, 1, 2])
|
||||
|
||||
|
||||
def test_dataset_config_none_episodes_ok():
|
||||
DatasetConfig(repo_id="user/repo", episodes=None)
|
||||
|
||||
|
||||
def test_dataset_config_empty_episodes_ok():
|
||||
DatasetConfig(repo_id="user/repo", episodes=[])
|
||||
@@ -142,9 +142,9 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
write_image(image_array, fpath)
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
@@ -243,10 +243,10 @@ def test_save_image_invalid_data(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
with patch("lerobot.datasets.image_writer.logger") as mock_logger:
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
mock_print.assert_called()
|
||||
mock_logger.error.assert_called()
|
||||
assert not fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
@@ -13,6 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -106,3 +109,28 @@ def test_shuffle():
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
|
||||
def test_negative_drop_last_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_last_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_last_frames=-1)
|
||||
|
||||
|
||||
def test_all_episodes_dropped_raises():
|
||||
# All episodes have 1 frame, drop_n_first_frames=1 removes all
|
||||
with pytest.raises(ValueError, match="No valid frames remain"):
|
||||
EpisodeAwareSampler([0, 1, 2], [1, 2, 3], drop_n_first_frames=1)
|
||||
|
||||
|
||||
def test_partial_episode_drop_warns(caplog):
|
||||
# Episode 0: 1 frame (dropped), Episode 1: 5 frames (kept)
|
||||
with caplog.at_level(logging.WARNING, logger="lerobot.datasets.sampler"):
|
||||
sampler = EpisodeAwareSampler([0, 1], [1, 6], drop_n_first_frames=1)
|
||||
# Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5
|
||||
assert sampler.indices == [2, 3, 4, 5]
|
||||
assert "Episode 0" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user