mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
feat(dataset): add subtask support (#2860)
* add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs
This commit is contained in:
@@ -27,7 +27,14 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGE,
|
||||
OBS_LANGUAGE,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario():
|
||||
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
|
||||
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
|
||||
assert attention_mask.shape == (10,)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_subtask method
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_missing_key():
|
||||
"""Test get_subtask returns None when subtask key is missing from complementary_data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task"}, # No "subtask" key
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_none_value():
|
||||
"""Test get_subtask returns None when subtask value is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": None},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_none_complementary_data():
|
||||
"""Test get_subtask returns None when complementary_data is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data=None, # No complementary data
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_string():
|
||||
"""Test get_subtask returns list with single string when subtask is a string."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up the cube"},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == ["pick up the cube"]
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_list_of_strings():
|
||||
"""Test get_subtask returns the list when subtask is already a list of strings."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
subtask_list = ["pick up", "move to target", "place down"]
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": subtask_list},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == subtask_list
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_integer():
|
||||
"""Test get_subtask returns None when subtask is an unsupported type (integer)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": 123},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_mixed_list():
|
||||
"""Test get_subtask returns None when subtask is a list with mixed types."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_dict():
|
||||
"""Test get_subtask returns None when subtask is a dictionary."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": {"key": "value"}},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_empty_string():
|
||||
"""Test get_subtask with empty string returns list with empty string."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ""},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == [""]
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_empty_list():
|
||||
"""Test get_subtask with empty list returns empty list."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": []},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for subtask tokenization in observation method
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_when_present():
|
||||
"""Test that subtask is tokenized and added to observation when present."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up the red cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check token structure
|
||||
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert isinstance(subtask_tokens, torch.Tensor)
|
||||
assert isinstance(subtask_attention_mask, torch.Tensor)
|
||||
assert subtask_tokens.shape == (8,)
|
||||
assert subtask_attention_mask.shape == (8,)
|
||||
assert subtask_attention_mask.dtype == torch.bool
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_not_added_when_none():
|
||||
"""Test that subtask tokens are NOT added to observation when subtask is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task"}, # No subtask
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were NOT added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
# But main task tokens should still be present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
|
||||
"""Test that subtask tokens are NOT added when subtask value is explicitly None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": None},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were NOT added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_list_of_strings():
|
||||
"""Test subtask tokenization with list of strings."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check token structure for batch
|
||||
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8
|
||||
assert subtask_attention_mask.shape == (2, 8)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_device_cpu():
|
||||
"""Test that subtask tokens are on CPU when other tensors are on CPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {OBS_STATE: torch.randn(10)} # CPU tensor
|
||||
action = torch.randn(5) # CPU tensor
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens are on CPU
|
||||
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
assert subtask_tokens.device.type == "cpu"
|
||||
assert subtask_attention_mask.device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_device_cuda():
|
||||
"""Test that subtask tokens are moved to CUDA when other tensors are on CUDA."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
|
||||
action = torch.randn(5).cuda() # CUDA tensor
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens are on CUDA
|
||||
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
assert subtask_tokens.device.type == "cuda"
|
||||
assert subtask_attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_preserves_other_observation_data():
|
||||
"""Test that subtask tokenization preserves other observation data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
original_state = torch.tensor([1.0, 2.0, 3.0])
|
||||
transition = create_transition(
|
||||
observation={"state": original_state.clone()},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that original observation data is preserved
|
||||
assert torch.equal(observation["state"], original_state)
|
||||
|
||||
# Check that both task and subtask tokens are present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_attention_mask_dtype():
|
||||
"""Test that subtask attention mask has correct dtype (bool)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert subtask_attention_mask.dtype == torch.bool
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_deterministic():
|
||||
"""Test that subtask tokenization is deterministic for the same input."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "consistent subtask"},
|
||||
)
|
||||
|
||||
result1 = processor(transition)
|
||||
result2 = processor(transition)
|
||||
|
||||
subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
# Results should be identical
|
||||
assert torch.equal(subtask_tokens1, subtask_tokens2)
|
||||
assert torch.equal(subtask_mask1, subtask_mask2)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
|
||||
"""Test subtask tokenization works correctly with DataProcessorPipeline."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "subtask instruction"},
|
||||
)
|
||||
|
||||
result = robot_processor(transition)
|
||||
|
||||
# Check that observation exists and both tokenizations were applied
|
||||
assert TransitionKey.OBSERVATION in result
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check task tokens
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
|
||||
# Check subtask tokens
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check shapes
|
||||
assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
|
||||
assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_not_added_for_unsupported_types():
|
||||
"""Test that subtask tokens are not added when subtask has unsupported type."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
# Test with integer subtask
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": 123},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Subtask tokens should NOT be added for unsupported types
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
# But main task tokens should still be present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
|
||||
Reference in New Issue
Block a user