feat(dataset): add subtask support (#2860)

* add subtask

* remove folder

* add docs

* update doc

* add testing

* update test

* update constant naming + doc

* more docs
This commit is contained in:
Jade Choghari
2026-01-30 10:29:37 -08:00
committed by GitHub
parent 5c6182176f
commit b18cef2e26
9 changed files with 1003 additions and 2 deletions
+190
View File
@@ -0,0 +1,190 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pandas as pd
import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
    """Subtask behaviour of LeRobotDataset on a dataset that ships subtask data."""

    @pytest.fixture
    def subtask_dataset(self):
        """Fetch the lerobot/pusht-subtask test dataset, restricted to episodes 1-11."""
        return LeRobotDataset(
            repo_id="lerobot/pusht-subtask",
            episodes=list(range(1, 12)),
        )

    def test_subtask_dataset_loads(self, subtask_dataset):
        """The dataset loads and is non-empty."""
        assert subtask_dataset is not None
        assert len(subtask_dataset) > 0

    def test_subtask_metadata_loaded(self, subtask_dataset):
        """Subtask metadata is exposed as a pandas DataFrame."""
        subtasks = subtask_dataset.meta.subtasks
        assert subtasks is not None
        assert isinstance(subtasks, pd.DataFrame)

    def test_subtask_index_in_features(self, subtask_dataset):
        """Datasets carrying subtasks advertise a subtask_index feature."""
        assert "subtask_index" in subtask_dataset.features

    def test_getitem_returns_subtask_string(self, subtask_dataset):
        """__getitem__ attaches a non-empty subtask string to the item."""
        sample = subtask_dataset[0]
        assert "subtask" in sample
        assert isinstance(sample["subtask"], str)
        assert sample["subtask"] != ""

    def test_getitem_has_subtask_index(self, subtask_dataset):
        """__getitem__ attaches subtask_index as a torch tensor."""
        sample = subtask_dataset[0]
        assert "subtask_index" in sample
        assert isinstance(sample["subtask_index"], torch.Tensor)

    def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
        """The item's subtask_index resolves to the matching row of meta.subtasks."""
        sample = subtask_dataset[0]
        row = sample["subtask_index"].item()
        # .name on the iloc-selected row yields the index label, i.e. the subtask string.
        assert sample["subtask"] == subtask_dataset.meta.subtasks.iloc[row].name

    def test_all_items_have_subtask(self, subtask_dataset):
        """Every sampled item carries a subtask string (first five items checked)."""
        for idx in range(min(len(subtask_dataset), 5)):
            sample = subtask_dataset[idx]
            assert "subtask" in sample
            assert isinstance(sample["subtask"], str)

    def test_task_and_subtask_coexist(self, subtask_dataset):
        """Items expose both the task and the subtask, each as a string."""
        sample = subtask_dataset[0]
        for key in ("task", "subtask"):
            assert key in sample
            assert isinstance(sample[key], str)
class TestSubtaskDatasetMissing:
    """Datasets without subtask data must keep working and omit subtask keys."""

    @pytest.fixture
    def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
        """Build, finalize and reload a tiny dataset that carries no subtask info."""
        spec = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
        ds = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=spec)
        for _ in range(5):
            ds.add_frame({"state": torch.randn(2), "task": "Test task"})
        ds.save_episode()
        ds.finalize()
        # Reload from disk so the reading path is exercised, not the writer's state.
        return LeRobotDataset(ds.repo_id, root=ds.root)

    def test_no_subtask_in_features(self, dataset_without_subtasks):
        """No subtask_index feature is advertised when none was recorded."""
        assert "subtask_index" not in dataset_without_subtasks.features

    def test_getitem_without_subtask(self, dataset_without_subtasks):
        """__getitem__ still returns the item, simply without a subtask key."""
        sample = dataset_without_subtasks[0]
        assert sample is not None
        assert "state" in sample
        assert "task" in sample
        assert "subtask" not in sample

    def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
        """meta.subtasks stays None when the dataset has no subtask data."""
        assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
    """Edge cases for subtask handling."""

    def test_subtask_with_multiple_episodes(self):
        """First and last items of a multi-episode dataset expose valid subtasks."""
        try:
            dataset = LeRobotDataset(
                repo_id="lerobot/pusht-subtask",
                episodes=list(range(1, 12)),
            )
        except Exception:
            pytest.skip("Could not load test-subtask dataset")
        for sample in (dataset[0], dataset[len(dataset) - 1]):
            assert "subtask" in sample
            assert isinstance(sample["subtask"], str)

    def test_subtask_index_consistency(self):
        """A given subtask_index always maps to the same subtask string."""
        try:
            dataset = LeRobotDataset(
                repo_id="lerobot/pusht-subtask",
                episodes=list(range(1, 12)),
            )
        except Exception:
            pytest.skip("Could not load test-subtask dataset")
        if len(dataset) < 2:
            pytest.skip("Dataset too small for this test")
        subtask_map = {}
        for i in range(min(len(dataset), 10)):
            sample = dataset[i]
            idx = sample["subtask_index"].item()
            subtask = sample["subtask"]
            # setdefault records the first sighting; later repeats must match it.
            subtask_map.setdefault(idx, subtask)
            assert subtask_map[idx] == subtask, (
                f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
            )
+464 -1
View File
@@ -27,7 +27,14 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
from lerobot.utils.constants import (
ACTION,
OBS_IMAGE,
OBS_LANGUAGE,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_STATE,
)
from tests.utils import require_package
@@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario():
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
assert attention_mask.shape == (10,)
# =============================================================================
# Tests for get_subtask method
# =============================================================================
@require_package("transformers")
def test_get_subtask_missing_key():
    """get_subtask yields None when complementary_data has no subtask entry at all."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task"},  # deliberately no subtask entry
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_none_value():
    """get_subtask yields None when the subtask entry is explicitly None."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": None},
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_none_complementary_data():
    """get_subtask yields None when the transition has no complementary data at all."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data=None,
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_string():
    """A plain-string subtask is wrapped into a single-element list."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up the cube"},
    )
    out = step.get_subtask(sample)
    assert out == ["pick up the cube"]
    assert isinstance(out, list) and len(out) == 1
@require_package("transformers")
def test_get_subtask_list_of_strings():
    """A list of subtask strings is passed through unchanged."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    phases = ["pick up", "move to target", "place down"]
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": phases},
    )
    out = step.get_subtask(sample)
    assert out == phases
    assert isinstance(out, list) and len(out) == 3
@require_package("transformers")
def test_get_subtask_unsupported_type_integer():
    """An integer subtask is an unsupported type and resolves to None."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": 123},
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_unsupported_type_mixed_list():
    """A list mixing strings and non-strings is rejected and resolves to None."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]},
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_unsupported_type_dict():
    """A dict-valued subtask is an unsupported type and resolves to None."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": {"key": "value"}},
    )
    assert step.get_subtask(sample) is None
@require_package("transformers")
def test_get_subtask_empty_string():
    """An empty-string subtask is still wrapped into a one-element list."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": ""},
    )
    assert step.get_subtask(sample) == [""]
@require_package("transformers")
def test_get_subtask_empty_list():
    """An empty-list subtask is passed through as an empty list."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": []},
    )
    assert step.get_subtask(sample) == []
# =============================================================================
# Tests for subtask tokenization in observation method
# =============================================================================
@require_package("transformers")
def test_subtask_tokenization_when_present():
    """A string subtask is tokenized into the OBS_LANGUAGE_SUBTASK_* observation keys."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up the red cube"},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs
    tokens = obs[OBS_LANGUAGE_SUBTASK_TOKENS]
    mask = obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
    # A single string yields flat (max_length,) tensors; the mask is boolean.
    assert isinstance(tokens, torch.Tensor) and tokens.shape == (8,)
    assert isinstance(mask, torch.Tensor) and mask.shape == (8,)
    assert mask.dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_not_added_when_none():
    """No subtask keys are added when complementary_data carries no subtask."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task"},  # no subtask supplied
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert OBS_LANGUAGE_SUBTASK_TOKENS not in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in obs
    # The main task is still tokenized as usual.
    assert f"{OBS_LANGUAGE}.tokens" in obs
    assert f"{OBS_LANGUAGE}.attention_mask" in obs
@require_package("transformers")
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
    """No subtask keys are added when the subtask value is explicitly None."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": None},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert OBS_LANGUAGE_SUBTASK_TOKENS not in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in obs
@require_package("transformers")
def test_subtask_tokenization_list_of_strings():
    """A list of subtasks tokenizes into batched (batch, max_length) tensors."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs
    # Two strings in, so batch dimension 2 with seq_len 8 each.
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (2, 8)
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].shape == (2, 8)
@require_package("transformers")
def test_subtask_tokenization_device_cpu():
    """Subtask token tensors stay on CPU when the transition tensors are on CPU."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={OBS_STATE: torch.randn(10)},
        action=torch.randn(5),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].device.type == "cpu"
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@require_package("transformers")
def test_subtask_tokenization_device_cuda():
    """Subtask token tensors follow the other transition tensors onto CUDA."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={OBS_STATE: torch.randn(10).cuda()},
        action=torch.randn(5).cuda(),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].device.type == "cuda"
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].device.type == "cuda"
@require_package("transformers")
def test_subtask_tokenization_preserves_other_observation_data():
    """Tokenizing a subtask leaves pre-existing observation entries untouched."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    state = torch.tensor([1.0, 2.0, 3.0])
    sample = create_transition(
        observation={"state": state.clone()},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    # Original state tensor survives unmodified.
    assert torch.equal(obs["state"], state)
    # Task and subtask tokenizations are both present side by side.
    for key in (
        f"{OBS_LANGUAGE}.tokens",
        f"{OBS_LANGUAGE}.attention_mask",
        OBS_LANGUAGE_SUBTASK_TOKENS,
        OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
    ):
        assert key in obs
@require_package("transformers")
def test_subtask_attention_mask_dtype():
    """The subtask attention mask is emitted with boolean dtype."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )
    obs = step(sample)[TransitionKey.OBSERVATION]
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_deterministic():
    """Running the processor twice on the same transition yields identical tokens."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "consistent subtask"},
    )
    first = step(sample)[TransitionKey.OBSERVATION]
    second = step(sample)[TransitionKey.OBSERVATION]
    assert torch.equal(first[OBS_LANGUAGE_SUBTASK_TOKENS], second[OBS_LANGUAGE_SUBTASK_TOKENS])
    assert torch.equal(
        first[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK], second[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
    )
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
    """Subtask tokenization also works when the step runs inside a DataProcessorPipeline."""
    mock_auto_tokenizer.from_pretrained.return_value = MockTokenizer(vocab_size=100)
    pipeline = DataProcessorPipeline(
        [TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)],
        to_transition=identity_transition,
        to_output=identity_transition,
    )
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "subtask instruction"},
    )
    result = pipeline(sample)
    assert TransitionKey.OBSERVATION in result
    obs = result[TransitionKey.OBSERVATION]
    # Task and subtask tokenizations are both present, each padded to max_length.
    assert f"{OBS_LANGUAGE}.tokens" in obs
    assert f"{OBS_LANGUAGE}.attention_mask" in obs
    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs
    assert obs[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
@require_package("transformers")
def test_subtask_not_added_for_unsupported_types():
    """Test that subtask tokens are not added when subtask has unsupported type."""
    mock_tokenizer = MockTokenizer(vocab_size=100)
    processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
    # Test with integer subtask
    transition = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": 123},  # int is not a valid subtask
    )
    result = processor(transition)
    observation = result[TransitionKey.OBSERVATION]
    # Subtask tokens should NOT be added for unsupported types
    assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
    # But main task tokens should still be present
    assert f"{OBS_LANGUAGE}.tokens" in observation