mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
chore(processor): add Step suffix to all processors (#1854)
* refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency
* refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules
* refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency
* refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency
* refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency
* refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency
* refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency
* refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency
* refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency
* refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency
* refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency
* refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency
* refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency
* refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency
* refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency
* refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency
* refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency
* refactor(processor): update config file name in test for RenameProcessorStep consistency
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Tests for the TokenizerProcessor class.
|
||||
Tests for the TokenizerProcessorStep class.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.constants import OBS_LANGUAGE
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessor, TransitionKey
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def test_basic_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -127,7 +127,7 @@ def test_basic_tokenization_with_tokenizer_object():
|
||||
"""Test basic string tokenization functionality using tokenizer object directly."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -161,7 +161,7 @@ def test_list_of_strings_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -189,7 +189,7 @@ def test_custom_keys(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", task_key="instruction", max_length=5)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -215,7 +215,7 @@ def test_none_complementary_data(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data=None)
|
||||
|
||||
@@ -230,7 +230,7 @@ def test_missing_task_key(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data={"other_field": "some value"})
|
||||
|
||||
@@ -245,7 +245,7 @@ def test_none_task_value(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(complementary_data={"task": None})
|
||||
|
||||
@@ -260,7 +260,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Test with integer task
|
||||
transition = create_transition(complementary_data={"task": 123})
|
||||
@@ -279,7 +279,7 @@ def test_unsupported_task_type(mock_auto_tokenizer):
|
||||
def test_no_tokenizer_error():
|
||||
"""Test that ValueError is raised when neither tokenizer nor tokenizer_name is provided."""
|
||||
with pytest.raises(ValueError, match="Either 'tokenizer' or 'tokenizer_name' must be provided"):
|
||||
TokenizerProcessor()
|
||||
TokenizerProcessorStep()
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@@ -290,7 +290,7 @@ def test_invalid_tokenizer_name_error():
|
||||
mock_auto_tokenizer.from_pretrained.side_effect = Exception("Model not found")
|
||||
|
||||
with pytest.raises(Exception, match="Model not found"):
|
||||
TokenizerProcessor(tokenizer_name="invalid-tokenizer")
|
||||
TokenizerProcessorStep(tokenizer_name="invalid-tokenizer")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@@ -300,7 +300,7 @@ def test_get_config_with_tokenizer_name(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer",
|
||||
max_length=256,
|
||||
task_key="instruction",
|
||||
@@ -327,7 +327,7 @@ def test_get_config_with_tokenizer_object():
|
||||
"""Test configuration serialization when using tokenizer object."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer=mock_tokenizer,
|
||||
max_length=256,
|
||||
task_key="instruction",
|
||||
@@ -357,7 +357,7 @@ def test_state_dict_methods(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Should return empty dict
|
||||
state = processor.state_dict()
|
||||
@@ -374,7 +374,7 @@ def test_reset_method(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
# Should not raise error
|
||||
processor.reset()
|
||||
@@ -387,7 +387,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
@@ -424,7 +424,7 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
original_processor = TokenizerProcessor(
|
||||
original_processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
|
||||
)
|
||||
|
||||
@@ -459,7 +459,9 @@ def test_save_and_load_pretrained_with_tokenizer_object():
|
||||
"""Test saving and loading processor with tokenizer object using overrides."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
|
||||
original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")
|
||||
original_processor = TokenizerProcessorStep(
|
||||
tokenizer=mock_tokenizer, max_length=32, task_key="instruction"
|
||||
)
|
||||
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[original_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
@@ -500,14 +502,14 @@ def test_registry_functionality():
|
||||
|
||||
# Check that we can retrieve it
|
||||
retrieved_class = ProcessorStepRegistry.get("tokenizer_processor")
|
||||
assert retrieved_class is TokenizerProcessor
|
||||
assert retrieved_class is TokenizerProcessorStep
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_features_basic():
|
||||
"""Test basic feature contract functionality."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=128)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=128)
|
||||
|
||||
input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
@@ -538,7 +540,7 @@ def test_features_basic():
|
||||
def test_features_with_custom_max_length():
|
||||
"""Test feature contract with custom max_length."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=64)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=64)
|
||||
|
||||
input_features = {}
|
||||
output_features = processor.transform_features(input_features)
|
||||
@@ -558,7 +560,7 @@ def test_features_with_custom_max_length():
|
||||
def test_features_existing_features():
|
||||
"""Test feature contract when tokenized features already exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=256)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=256)
|
||||
|
||||
input_features = {
|
||||
f"{OBS_LANGUAGE}.tokens": PolicyFeature(type=FeatureType.LANGUAGE, shape=(100,)),
|
||||
@@ -595,7 +597,7 @@ def test_tokenization_parameters(mock_auto_tokenizer):
|
||||
tracking_tokenizer = TrackingMockTokenizer()
|
||||
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(
|
||||
processor = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer",
|
||||
max_length=16,
|
||||
padding="longest",
|
||||
@@ -627,7 +629,7 @@ def test_preserves_other_complementary_data(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer")
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer")
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -662,7 +664,7 @@ def test_deterministic_tokenization(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -690,7 +692,7 @@ def test_empty_string_task(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=8)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -714,7 +716,7 @@ def test_very_long_task(mock_auto_tokenizer):
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
|
||||
processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=5, truncation=True)
|
||||
|
||||
long_task = " ".join(["word"] * 100) # Very long task
|
||||
transition = create_transition(
|
||||
@@ -764,7 +766,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
mock_auto_tokenizer.from_pretrained.return_value = tracking_tokenizer
|
||||
|
||||
# Test left padding
|
||||
processor_left = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="left")
|
||||
processor_left = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=10, padding_side="left"
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
@@ -776,7 +780,9 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
assert tracking_tokenizer.padding_side_calls[-1] == "left"
|
||||
|
||||
# Test right padding (default)
|
||||
processor_right = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=10, padding_side="right")
|
||||
processor_right = TokenizerProcessorStep(
|
||||
tokenizer_name="test-tokenizer", max_length=10, padding_side="right"
|
||||
)
|
||||
|
||||
processor_right(transition)
|
||||
|
||||
@@ -787,7 +793,7 @@ def test_custom_padding_side(mock_auto_tokenizer):
|
||||
def test_device_detection_cpu():
|
||||
"""Test that tokenized tensors stay on CPU when other tensors are on CPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {"observation.state": torch.randn(10)} # CPU tensor
|
||||
@@ -811,7 +817,7 @@ def test_device_detection_cpu():
|
||||
def test_device_detection_cuda():
|
||||
"""Test that tokenized tensors are moved to CUDA when other tensors are on CUDA."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {"observation.state": torch.randn(10).cuda()} # CUDA tensor
|
||||
@@ -836,7 +842,7 @@ def test_device_detection_cuda():
|
||||
def test_device_detection_multi_gpu():
|
||||
"""Test that tokenized tensors match device in multi-GPU setup."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Test with tensors on cuda:1
|
||||
device = torch.device("cuda:1")
|
||||
@@ -860,7 +866,7 @@ def test_device_detection_multi_gpu():
|
||||
def test_device_detection_no_tensors():
|
||||
"""Test that tokenized tensors stay on CPU when no other tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with no tensors
|
||||
transition = create_transition(
|
||||
@@ -882,7 +888,7 @@ def test_device_detection_no_tensors():
|
||||
def test_device_detection_mixed_devices():
|
||||
"""Test device detection when tensors are on different devices (uses first found)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# Create transition with mixed devices
|
||||
@@ -910,7 +916,7 @@ def test_device_detection_mixed_devices():
|
||||
def test_device_detection_from_action():
|
||||
"""Test that device is detected from action tensor when no observation tensors exist."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with action on CUDA but no observation tensors
|
||||
observation = {"metadata": {"key": "value"}} # No tensors in observation
|
||||
@@ -933,7 +939,7 @@ def test_device_detection_from_action():
|
||||
def test_device_detection_preserves_dtype():
|
||||
"""Test that device detection doesn't affect dtype of tokenized tensors."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with float tensor (to test dtype isn't affected)
|
||||
observation = {"observation.state": torch.randn(10, dtype=torch.float16)}
|
||||
@@ -953,15 +959,15 @@ def test_device_detection_preserves_dtype():
|
||||
@require_package("transformers")
|
||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||
def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
"""Test that TokenizerProcessor works correctly with DeviceProcessor in pipeline."""
|
||||
from lerobot.processor import DeviceProcessor
|
||||
"""Test that TokenizerProcessorStep works correctly with DeviceProcessorStep in pipeline."""
|
||||
from lerobot.processor import DeviceProcessorStep
|
||||
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
# Create pipeline with TokenizerProcessor then DeviceProcessor
|
||||
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessor(device="cuda:0")
|
||||
# Create pipeline with TokenizerProcessorStep then DeviceProcessorStep
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
device_processor = DeviceProcessorStep(device="cuda:0")
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
@@ -975,7 +981,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
|
||||
result = robot_processor(transition)
|
||||
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessor)
|
||||
# All tensors should end up on CUDA (moved by DeviceProcessorStep)
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert result[TransitionKey.ACTION].device.type == "cuda"
|
||||
|
||||
@@ -991,7 +997,7 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
|
||||
def test_simulated_accelerate_scenario():
|
||||
"""Test scenario simulating Accelerate with data already on GPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=10)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Simulate Accelerate scenario: batch already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
Reference in New Issue
Block a user