#!/usr/bin/env python # Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import draccus import pytest from lerobot.scripts.lerobot_edit_dataset import ( ConvertImageToVideoConfig, DeleteEpisodesConfig, EditDatasetConfig, InfoConfig, MergeConfig, ModifyTasksConfig, OperationConfig, RemoveFeatureConfig, SplitConfig, ) def parse_cfg(cli_args: list[str]) -> EditDatasetConfig: """Helper to parse CLI args into an EditDatasetConfig via draccus.""" return draccus.parse(EditDatasetConfig, args=cli_args) class TestOperationTypeParsing: """Test that --operation.type correctly selects the right config subclass.""" @pytest.mark.parametrize( "type_name, expected_cls", [ ("delete_episodes", DeleteEpisodesConfig), ("split", SplitConfig), ("merge", MergeConfig), ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), ("info", InfoConfig), ], ) def test_operation_type_resolves_correct_class(self, type_name, expected_cls): cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) assert isinstance(cfg.operation, expected_cls), ( f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" ) @pytest.mark.parametrize( "type_name, expected_cls", [ ("delete_episodes", DeleteEpisodesConfig), ("split", SplitConfig), ("merge", MergeConfig), ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), ("info", InfoConfig), ], ) def test_get_choice_name_roundtrips(self, type_name, expected_cls): cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) assert resolved_name == type_name