mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949)
* fix(edit dataset operation): fixing dataset tools CLI operation type specification * test(edit dataset operation): adding tests for dataset tools operation type specification * chore(format): running pre-commit * chore(backward compatibility): adding a type property in OperationConfig for backward compatibility Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
This commit is contained in:
@@ -109,11 +109,14 @@ Using JSON config file:
|
|||||||
--config_path path/to/edit_config.json
|
--config_path path/to/edit_config.json
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import draccus
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.datasets.dataset_tools import (
|
from lerobot.datasets.dataset_tools import (
|
||||||
convert_image_to_video_dataset,
|
convert_image_to_video_dataset,
|
||||||
@@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeleteEpisodesConfig:
|
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||||
type: str = "delete_episodes"
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
return self.get_choice_name(self.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("delete_episodes")
|
||||||
|
@dataclass
|
||||||
|
class DeleteEpisodesConfig(OperationConfig):
|
||||||
episode_indices: list[int] | None = None
|
episode_indices: list[int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("split")
|
||||||
@dataclass
|
@dataclass
|
||||||
class SplitConfig:
|
class SplitConfig(OperationConfig):
|
||||||
type: str = "split"
|
|
||||||
splits: dict[str, float | list[int]] | None = None
|
splits: dict[str, float | list[int]] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("merge")
|
||||||
@dataclass
|
@dataclass
|
||||||
class MergeConfig:
|
class MergeConfig(OperationConfig):
|
||||||
type: str = "merge"
|
|
||||||
repo_ids: list[str] | None = None
|
repo_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("remove_feature")
|
||||||
@dataclass
|
@dataclass
|
||||||
class RemoveFeatureConfig:
|
class RemoveFeatureConfig(OperationConfig):
|
||||||
type: str = "remove_feature"
|
|
||||||
feature_names: list[str] | None = None
|
feature_names: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("modify_tasks")
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModifyTasksConfig:
|
class ModifyTasksConfig(OperationConfig):
|
||||||
type: str = "modify_tasks"
|
|
||||||
new_task: str | None = None
|
new_task: str | None = None
|
||||||
episode_tasks: dict[str, str] | None = None
|
episode_tasks: dict[str, str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@OperationConfig.register_subclass("convert_image_to_video")
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConvertImageToVideoConfig:
|
class ConvertImageToVideoConfig(OperationConfig):
|
||||||
type: str = "convert_image_to_video"
|
|
||||||
output_dir: str | None = None
|
output_dir: str | None = None
|
||||||
vcodec: str = "libsvtav1"
|
vcodec: str = "libsvtav1"
|
||||||
pix_fmt: str = "yuv420p"
|
pix_fmt: str = "yuv420p"
|
||||||
@@ -177,14 +187,7 @@ class ConvertImageToVideoConfig:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class EditDatasetConfig:
|
class EditDatasetConfig:
|
||||||
repo_id: str
|
repo_id: str
|
||||||
operation: (
|
operation: OperationConfig
|
||||||
DeleteEpisodesConfig
|
|
||||||
| SplitConfig
|
|
||||||
| MergeConfig
|
|
||||||
| RemoveFeatureConfig
|
|
||||||
| ModifyTasksConfig
|
|
||||||
| ConvertImageToVideoConfig
|
|
||||||
)
|
|
||||||
root: str | None = None
|
root: str | None = None
|
||||||
new_repo_id: str | None = None
|
new_repo_id: str | None = None
|
||||||
push_to_hub: bool = False
|
push_to_hub: bool = False
|
||||||
@@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
|||||||
elif operation_type == "convert_image_to_video":
|
elif operation_type == "convert_image_to_video":
|
||||||
handle_convert_image_to_video(cfg)
|
handle_convert_image_to_video(cfg)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
available = ", ".join(OperationConfig.get_known_choices())
|
||||||
f"Unknown operation type: {operation_type}\n"
|
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
|
||||||
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import draccus
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.scripts.lerobot_edit_dataset import (
|
||||||
|
ConvertImageToVideoConfig,
|
||||||
|
DeleteEpisodesConfig,
|
||||||
|
EditDatasetConfig,
|
||||||
|
MergeConfig,
|
||||||
|
ModifyTasksConfig,
|
||||||
|
OperationConfig,
|
||||||
|
RemoveFeatureConfig,
|
||||||
|
SplitConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
|
||||||
|
"""Helper to parse CLI args into an EditDatasetConfig via draccus."""
|
||||||
|
return draccus.parse(EditDatasetConfig, args=cli_args)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOperationTypeParsing:
|
||||||
|
"""Test that --operation.type correctly selects the right config subclass."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"type_name, expected_cls",
|
||||||
|
[
|
||||||
|
("delete_episodes", DeleteEpisodesConfig),
|
||||||
|
("split", SplitConfig),
|
||||||
|
("merge", MergeConfig),
|
||||||
|
("remove_feature", RemoveFeatureConfig),
|
||||||
|
("modify_tasks", ModifyTasksConfig),
|
||||||
|
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
|
||||||
|
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||||
|
assert isinstance(cfg.operation, expected_cls), (
|
||||||
|
f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"type_name, expected_cls",
|
||||||
|
[
|
||||||
|
("delete_episodes", DeleteEpisodesConfig),
|
||||||
|
("split", SplitConfig),
|
||||||
|
("merge", MergeConfig),
|
||||||
|
("remove_feature", RemoveFeatureConfig),
|
||||||
|
("modify_tasks", ModifyTasksConfig),
|
||||||
|
("convert_image_to_video", ConvertImageToVideoConfig),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_choice_name_roundtrips(self, type_name, expected_cls):
|
||||||
|
cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name])
|
||||||
|
resolved_name = OperationConfig.get_choice_name(type(cfg.operation))
|
||||||
|
assert resolved_name == type_name
|
||||||
Reference in New Issue
Block a user