fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949)

* fix(edit dataset operation): fixing dataset tools CLI operation type specification

* test(edit dataset operation): adding tests for dataset tools operation type specification

* chore(format): running pre-commit

* chore(backward compatibility): adding a type property in OperationConfig for backward compatibility

Signed-off-by: Caroline Pascal <caroline8.pascal@gmail.com>
This commit is contained in:
Caroline Pascal
2026-02-12 18:56:04 +01:00
committed by GitHub
parent 3615160d89
commit adebbcf090
2 changed files with 96 additions and 24 deletions
+25 -24
View File
@@ -109,11 +109,14 @@ Using JSON config file:
--config_path path/to/edit_config.json
"""
import abc
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path
import draccus
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
convert_image_to_video_dataset,
@@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging
@dataclass
class DeleteEpisodesConfig:
type: str = "delete_episodes"
class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@OperationConfig.register_subclass("delete_episodes")
@dataclass
class DeleteEpisodesConfig(OperationConfig):
episode_indices: list[int] | None = None
@OperationConfig.register_subclass("split")
@dataclass
class SplitConfig:
type: str = "split"
class SplitConfig(OperationConfig):
splits: dict[str, float | list[int]] | None = None
@OperationConfig.register_subclass("merge")
@dataclass
class MergeConfig:
type: str = "merge"
class MergeConfig(OperationConfig):
repo_ids: list[str] | None = None
@OperationConfig.register_subclass("remove_feature")
@dataclass
class RemoveFeatureConfig:
type: str = "remove_feature"
class RemoveFeatureConfig(OperationConfig):
feature_names: list[str] | None = None
@OperationConfig.register_subclass("modify_tasks")
@dataclass
class ModifyTasksConfig:
type: str = "modify_tasks"
class ModifyTasksConfig(OperationConfig):
new_task: str | None = None
episode_tasks: dict[str, str] | None = None
@OperationConfig.register_subclass("convert_image_to_video")
@dataclass
class ConvertImageToVideoConfig:
type: str = "convert_image_to_video"
class ConvertImageToVideoConfig(OperationConfig):
output_dir: str | None = None
vcodec: str = "libsvtav1"
pix_fmt: str = "yuv420p"
@@ -177,14 +187,7 @@ class ConvertImageToVideoConfig:
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig
| SplitConfig
| MergeConfig
| RemoveFeatureConfig
| ModifyTasksConfig
| ConvertImageToVideoConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
@@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video"
)
available = ", ".join(OperationConfig.get_known_choices())
raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}")
def main() -> None: