fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949)

* fix(edit dataset operation): fixing dataset tools CLI operation type specification

* test(edit dataset operation): adding tests for dataset tools operation type specification

* chore(format): running pre-commit

* chore(backward compatibility): adding a type property in OperationConfig for backward compatibility

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
This commit is contained in:
Caroline Pascal
2026-02-12 18:56:04 +01:00
committed by GitHub
parent 3615160d89
commit adebbcf090
2 changed files with 96 additions and 24 deletions
+25 -24
View File
@@ -109,11 +109,14 @@ Using JSON config file:
--config_path path/to/edit_config.json --config_path path/to/edit_config.json
""" """
import abc
import logging import logging
import shutil import shutil
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import draccus
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.datasets.dataset_tools import ( from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset, convert_image_to_video_dataset,
@@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging
@dataclass @dataclass
class DeleteEpisodesConfig: class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
type: str = "delete_episodes" @property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@OperationConfig.register_subclass("delete_episodes")
@dataclass
class DeleteEpisodesConfig(OperationConfig):
episode_indices: list[int] | None = None episode_indices: list[int] | None = None
@OperationConfig.register_subclass("split")
@dataclass @dataclass
class SplitConfig: class SplitConfig(OperationConfig):
type: str = "split"
splits: dict[str, float | list[int]] | None = None splits: dict[str, float | list[int]] | None = None
@OperationConfig.register_subclass("merge")
@dataclass @dataclass
class MergeConfig: class MergeConfig(OperationConfig):
type: str = "merge"
repo_ids: list[str] | None = None repo_ids: list[str] | None = None
@OperationConfig.register_subclass("remove_feature")
@dataclass @dataclass
class RemoveFeatureConfig: class RemoveFeatureConfig(OperationConfig):
type: str = "remove_feature"
feature_names: list[str] | None = None feature_names: list[str] | None = None
@OperationConfig.register_subclass("modify_tasks")
@dataclass @dataclass
class ModifyTasksConfig: class ModifyTasksConfig(OperationConfig):
type: str = "modify_tasks"
new_task: str | None = None new_task: str | None = None
episode_tasks: dict[str, str] | None = None episode_tasks: dict[str, str] | None = None
@OperationConfig.register_subclass("convert_image_to_video")
@dataclass @dataclass
class ConvertImageToVideoConfig: class ConvertImageToVideoConfig(OperationConfig):
type: str = "convert_image_to_video"
output_dir: str | None = None output_dir: str | None = None
vcodec: str = "libsvtav1" vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p" pix_fmt: str = "yuv420p"
@@ -177,14 +187,7 @@ class ConvertImageToVideoConfig:
@dataclass @dataclass
class EditDatasetConfig: class EditDatasetConfig:
repo_id: str repo_id: str
operation: ( operation: OperationConfig
DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
)
root: str | None = None root: str | None = None
new_repo_id: str | None = None new_repo_id: str | None = None
push_to_hub: bool = False push_to_hub: bool = False
@@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
elif operation_type == "convert_image_to_video": elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg) handle_convert_image_to_video(cfg)
else: else:
raise ValueError( available = ", ".join(OperationConfig.get_known_choices())
f"Unknown operation type: {operation_type}\n" raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
)
def main() -> None: def main() -> None:
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import draccus
import pytest
from lerobot.scripts.lerobot_edit_dataset import (
ConvertImageToVideoConfig,
DeleteEpisodesConfig,
EditDatasetConfig,
MergeConfig,
ModifyTasksConfig,
OperationConfig,
RemoveFeatureConfig,
SplitConfig,
)
def parse_cfg(cli_args: list[str]) -> EditDatasetConfig:
    """Run draccus over ``cli_args`` and return the resulting EditDatasetConfig.

    Args:
        cli_args: Command-line tokens, e.g. ``["--repo_id", "test/repo"]``.

    Returns:
        The config object produced by ``draccus.parse`` for ``EditDatasetConfig``.
    """
    parsed_config = draccus.parse(EditDatasetConfig, args=cli_args)
    return parsed_config
# The (choice name, config class) table is declared once on the class; pytest
# applies a class-level ``parametrize`` mark to every test method below, so each
# method still runs once per registered operation.
@pytest.mark.parametrize(
    "type_name, expected_cls",
    [
        ("delete_episodes", DeleteEpisodesConfig),
        ("split", SplitConfig),
        ("merge", MergeConfig),
        ("remove_feature", RemoveFeatureConfig),
        ("modify_tasks", ModifyTasksConfig),
        ("convert_image_to_video", ConvertImageToVideoConfig),
    ],
)
class TestOperationTypeParsing:
    """Check that ``--operation.type`` picks the matching operation config class."""

    def test_operation_type_resolves_correct_class(self, type_name, expected_cls):
        """The parsed ``operation`` is an instance of the class tied to ``type_name``."""
        cli_args = ["--repo_id", "test/repo", "--operation.type", type_name]
        cfg = parse_cfg(cli_args)
        assert isinstance(cfg.operation, expected_cls), (
            f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}"
        )

    def test_get_choice_name_roundtrips(self, type_name, expected_cls):
        """``OperationConfig.get_choice_name`` maps the parsed class back to ``type_name``."""
        # ``expected_cls`` is unused here; it is accepted only because the
        # shared parametrization supplies both arguments.
        cli_args = ["--repo_id", "test/repo", "--operation.type", type_name]
        operation_cls = type(parse_cfg(cli_args).operation)
        assert OperationConfig.get_choice_name(operation_cls) == type_name