Remove dataset consolidate (#752)

This commit is contained in:
Simon Alibert
2025-02-19 16:02:54 +01:00
committed by GitHub
parent 6fe42a72db
commit 969ef745a2
6 changed files with 93 additions and 128 deletions
+40 -12
View File
@@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
    """Validate one frame against the dataset's feature specification.

    The frame must provide every key of ``features`` that is not part of
    ``DEFAULT_FEATURES``, plus ``"task"``; ``"timestamp"`` is optional.
    All problems found are accumulated into a single message so the caller
    sees every issue at once instead of only the first one.

    Args:
        frame: Mapping from feature name to the value recorded for this frame.
        features: Mapping from feature name to its spec; each spec is handed
            to ``validate_feature_dtype_and_shape``.

    Raises:
        ValueError: If any validator reported a problem. The exception text
            is the concatenation of every individual error message.
    """
    optional_features = {"timestamp"}
    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
    actual_features = set(frame.keys())

    error_message = validate_features_presence(actual_features, expected_features, optional_features)

    # "task" has no entry to look up in `features`, so it is checked on its own.
    if "task" in frame:
        error_message += validate_feature_string("task", frame["task"])

    common_features = actual_features & (expected_features | optional_features)
    # Sorted so that the combined error text is identical from run to run
    # (plain set iteration order depends on string hash randomization).
    for name in sorted(common_features - {"task"}):
        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])

    if error_message:
        raise ValueError(error_message)
def check_features_presence(
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
@@ -679,20 +679,22 @@ def check_features_presence(
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
    """Dispatch validation of ``value`` according to the feature's declared dtype.

    Args:
        name: Feature name, forwarded to the specific validator for its message.
        feature: Feature spec; its ``"dtype"`` and ``"shape"`` entries are read.
        value: The value recorded for this feature.

    Returns:
        Whatever the selected validator returns (an error-message string that
        is empty when nothing is wrong).

    Raises:
        NotImplementedError: If ``feature["dtype"]`` is neither a valid numpy
            dtype string, ``"image"``, ``"video"`` nor ``"string"``.
    """
    expected_dtype = feature["dtype"]
    expected_shape = feature["shape"]

    if is_valid_numpy_dtype_string(expected_dtype):
        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
    elif expected_dtype in ["image", "video"]:
        return validate_feature_image_or_video(name, expected_shape, value)
    elif expected_dtype == "string":
        # Strings carry no shape to verify; only the Python type is checked.
        return validate_feature_string(name, value)
    else:
        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
return error_message
def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
return error_message
def validate_feature_string(name: str, value: str) -> str:
    """Return an error message if ``value`` is not a ``str``, else an empty string.

    Args:
        name: Feature name, used only to build the message.
        value: The value to check.

    Returns:
        ``""`` when ``value`` is a ``str``; otherwise a newline-terminated
        description of the type mismatch, ready to be appended to other
        validation messages.
    """
    if not isinstance(value, str):
        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
    return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
    """Check that an episode buffer is complete and consistent before it is saved.

    Args:
        episode_buffer: Accumulated episode data. Must contain ``"size"``,
            ``"task"`` and ``"episode_index"``; every remaining key is
            compared against ``features``.
        total_episodes: Value that ``episode_buffer["episode_index"]`` is
            required to equal.
        features: Feature spec mapping; only its keys are used here.

    Raises:
        ValueError: If ``"size"`` or ``"task"`` is missing, if the buffer holds
            zero frames, or if the buffer's keys (excluding ``"task"`` and
            ``"size"``) differ from the keys of ``features``.
        NotImplementedError: If ``episode_buffer["episode_index"]`` differs
            from ``total_episodes``.
        KeyError: If ``"episode_index"`` is missing (it is read by subscript
            without a prior presence check).
    """
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    # "task" and "size" are bookkeeping entries, not features.
    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    feature_keys = set(features)
    if buffer_keys != feature_keys:
        # Adjacent literals are joined with no separator, so each sentence
        # needs its own explicit line break to stay readable.
        raise ValueError(
            "Features from `episode_buffer` don't match the ones in `features`.\n"
            f"In episode_buffer not in features: {buffer_keys - feature_keys}\n"
            f"In features not in episode_buffer: {feature_keys - buffer_keys}"
        )