[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by AdilZouitine
parent 761a2dbcb3
commit 8e6d5f504c
97 changed files with 1596 additions and 492 deletions
+36 -12
View File
@@ -19,7 +19,10 @@ from lerobot.common.datasets.utils import load_image_as_numpy
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
dataset_len: int,
min_num_samples: int = 100,
max_num_samples: int = 10_000,
power: float = 0.75,
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
@@ -43,14 +46,18 @@ def sample_indices(data_len: int) -> list[int]:
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
def auto_downsample_height_width(
img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300
):
_, height, width = img.shape
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
downsample_factor = (
int(width / target_size) if width > height else int(height / target_size)
)
return img[:, ::downsample_factor, ::downsample_factor]
@@ -72,7 +79,9 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
def get_feature_stats(
array: np.ndarray, axis: tuple, keepdims: bool
) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
@@ -82,7 +91,9 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray], features: dict
) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
@@ -96,12 +107,15 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims
)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / 255.0, axis=0)
for k, v in ep_stats[key].items()
}
return ep_stats
@@ -116,14 +130,22 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
raise ValueError(
"Number of dimensions must be at least 1, and is 0 instead."
)
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
raise ValueError(
f"Shape of 'count' must be (1), but is {v.shape} instead."
)
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
raise ValueError(
f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead."
)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_feature_stats(
stats_ft_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
@@ -152,7 +174,9 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
def aggregate_stats(
stats_list: list[dict[str, dict]],
) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
+9 -3
View File
@@ -58,7 +58,9 @@ def resolve_delta_timestamps(
if key == "action" and cfg.action_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.action_delta_indices]
if key.startswith("observation.") and cfg.observation_delta_indices is not None:
delta_timestamps[key] = [i / ds_meta.fps for i in cfg.observation_delta_indices]
delta_timestamps[key] = [
i / ds_meta.fps for i in cfg.observation_delta_indices
]
if len(delta_timestamps) == 0:
delta_timestamps = None
@@ -79,7 +81,9 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
LeRobotDataset | MultiLeRobotDataset
"""
image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
ImageTransforms(cfg.dataset.image_transforms)
if cfg.dataset.image_transforms.enable
else None
)
if isinstance(cfg.dataset.repo_id, str):
@@ -113,6 +117,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
if cfg.dataset.use_imagenet_stats:
for key in dataset.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
dataset.meta.stats[key][stats_type] = torch.tensor(
stats, dtype=torch.float32
)
return dataset
+6 -2
View File
@@ -38,10 +38,14 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
def image_array_to_pil_image(
image_array: np.ndarray, range_check: bool = True
) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
raise ValueError(
f"The array has {image_array.ndim} dimensions, but 3 is expected for an image."
)
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
+81 -25
View File
@@ -108,7 +108,9 @@ class LeRobotDatasetMetadata:
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
self.episodes_stats = backward_compatible_episodes_stats(
self.stats, self.episodes
)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
@@ -238,7 +240,9 @@ class LeRobotDatasetMetadata:
Given a task in natural language, add it to the dictionary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
raise ValueError(
f"The task '{task}' already exists and can't be added twice."
)
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
@@ -281,7 +285,11 @@ class LeRobotDatasetMetadata:
write_episode(episode_dict, self.root)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
self.stats = (
aggregate_stats([self.stats, episode_stats])
if self.stats
else episode_stats
)
write_episode_stats(episode_index, episode_stats, self.root)
def update_video_info(self) -> None:
@@ -345,13 +353,17 @@ class LeRobotDatasetMetadata:
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
raise ValueError(
f"Feature names should not contain '/'. Found '/' in feature '{key}'."
)
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, features, use_videos
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -482,7 +494,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.video_backend = (
video_backend if video_backend else get_safe_default_codec()
)
self.delta_indices = None
# Unused attributes
@@ -495,28 +509,39 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
if self.episodes is not None and self.meta._version >= packaging.version.parse(
"v2.1"
):
episodes_stats = [
self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes
]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
assert all(
(self.root / fpath).is_file()
for fpath in self.get_episodes_file_paths()
)
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
self.episode_data_index = get_episode_data_index(
self.meta.episodes, self.episodes
)
# Check timestamps
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
check_timestamps_sync(
timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s
)
# Setup delta_indices
if self.delta_timestamps is not None:
@@ -568,7 +593,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
if not hub_api.file_exists(
self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch
):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
@@ -576,8 +603,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
if tag_version:
with contextlib.suppress(RevisionNotFoundError):
hub_api.delete_tag(self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
hub_api.create_tag(self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.delete_tag(
self.repo_id, tag=CODEBASE_VERSION, repo_type="dataset"
)
hub_api.create_tag(
self.repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
def pull_from_repo(
self,
@@ -609,7 +640,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
episodes = (
self.episodes
if self.episodes is not None
else list(range(self.meta.total_episodes))
)
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
@@ -640,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
hf_dataset = datasets.Dataset.from_dict(
ft_dict, features=features, split="train"
)
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
@@ -726,7 +763,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
def _query_videos(
self, query_timestamps: dict[str, list[float]], ep_idx: int
) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@@ -735,7 +774,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
frames = decode_video_frames(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames.squeeze(0)
return item
@@ -789,7 +830,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
current_ep_idx = (
self.meta.total_episodes if episode_index is None else episode_index
)
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
@@ -887,7 +930,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["index"] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
# Add new tasks to the tasks dictionary
@@ -897,12 +942,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.meta.add_task(task)
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
episode_buffer["task_index"] = np.array(
[self.meta.get_task_index(task) for task in tasks]
)
for key, ft in self.features.items():
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
"image",
"video",
]:
continue
episode_buffer[key] = np.stack(episode_buffer[key])
@@ -944,7 +994,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = datasets.Dataset.from_dict(
episode_dict, features=self.hf_features, split="train"
)
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
@@ -1063,7 +1115,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
obj.video_backend = (
video_backend if video_backend is not None else get_safe_default_codec()
)
return obj
@@ -1088,7 +1142,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
self.tolerances_s = (
tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
)
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
+6 -2
View File
@@ -141,12 +141,16 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1])
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
sharpness_factor = (
torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
)
return {"sharpness_factor": sharpness_factor}
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
@dataclass
+54 -17
View File
@@ -135,7 +135,9 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
raise NotImplementedError(
f"The value '{value}' of type '{type(value)}' is not supported."
)
return unflatten_dict(serialized_dict)
@@ -214,7 +216,10 @@ def write_task(task_index: int, task: dict, local_dir: Path):
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {
item["task_index"]: item["task"]
for item in sorted(tasks, key=lambda x: x["task_index"])
}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
@@ -225,13 +230,19 @@ def write_episode(episode: dict, local_dir: Path):
def load_episodes(local_dir: Path) -> dict:
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
return {
item["episode_index"]: item
for item in sorted(episodes, key=lambda x: x["episode_index"])
}
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
episode_stats = {
"episode_index": episode_index,
"stats": serialize_dict(episode_stats),
}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
@@ -275,7 +286,9 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
items_dict[key] = [
x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]
]
return items_dict
@@ -328,7 +341,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
packaging.version.parse(version)
if not isinstance(version, packaging.version.Version)
else version
)
hub_versions = get_repo_versions(repo_id)
@@ -349,12 +364,16 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
v
for v in hub_versions
if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
logging.warning(
f"Revision {version} for {repo_id} not found, using version v{return_version}"
)
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
@@ -461,7 +480,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
episode_lengths = {
ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()
}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@@ -511,7 +532,9 @@ def check_timestamps_sync(
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
ignored_diffs = (
episode_data_index["to"][:-1] - 1
) # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
@@ -720,14 +743,18 @@ def validate_frame(frame: dict, features: dict):
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(
actual_features, expected_features, optional_features
)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
error_message += validate_feature_dtype_and_shape(
name, features[name], frame[name]
)
if error_message:
raise ValueError(error_message)
@@ -750,7 +777,9 @@ def validate_features_presence(
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(
name: str, feature: dict, value: np.ndarray | PILImage.Image | str
):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
@@ -760,7 +789,9 @@ def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
raise NotImplementedError(
f"The feature dtype '{expected_dtype}' is not implemented yet."
)
def validate_feature_numpy_array(
@@ -782,13 +813,17 @@ def validate_feature_numpy_array(
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(
name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image
):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
if len(actual_shape) != 3 or (
actual_shape != (c, h, w) and actual_shape != (h, w, c)
):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
@@ -819,7 +854,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
@@ -35,22 +35,30 @@ def fix_dataset(repo_id: str) -> str:
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
lerobot_metadata = LeRobotDatasetMetadata(
repo_id, revision=V20, force_cache_sync=True
)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
meta_features = {
key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"
}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
raise ValueError(
f"In parquet not in info.json: {parquet_features - meta_features}"
)
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
logging.warning(
f"In info.json not in parquet: {meta_features - parquet_features}"
)
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
@@ -37,8 +37,16 @@ import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
from lerobot.common.datasets.utils import (
EPISODES_STATS_PATH,
STATS_PATH,
load_stats,
write_info,
)
from lerobot.common.datasets.v21.convert_stats import (
check_aggregate_stats,
convert_stats,
)
V20 = "v2.0"
V21 = "v2.1"
@@ -79,13 +87,21 @@ def convert_dataset(
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
repo_id=dataset.repo_id,
filename=STATS_PATH,
revision=branch,
repo_type="dataset",
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
path_in_repo=STATS_PATH,
repo_id=dataset.repo_id,
revision=branch,
repo_type="dataset",
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
hub_api.create_tag(
repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset"
)
if __name__ == "__main__":
+18 -5
View File
@@ -17,12 +17,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
get_feature_stats,
sample_indices,
)
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
def sample_episode_video_frames(
dataset: LeRobotDataset, episode_index: int, ft_key: str
) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
@@ -45,11 +51,14 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
ep_stats[key] = get_feature_stats(
ep_ft_data, axis=axes_to_reduce, keepdims=keepdims
)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v, axis=0)
for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
@@ -95,5 +104,9 @@ def check_aggregate_stats(
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
val,
reference_stats[key][stat],
rtol=rtol,
atol=atol,
err_msg=err_msg,
)
+3 -1
View File
@@ -65,7 +65,9 @@ def decode_video_frames(
if backend == "torchcodec":
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
elif backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise ValueError(f"Unsupported video backend: {backend}")