From b18cef2e260a80db6cbe2327140950964c797b46 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 30 Jan 2026 10:29:37 -0800 Subject: [PATCH] feat(dataset): add subtask support (#2860) * add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs --- docs/source/_toctree.yml | 2 + docs/source/dataset_subtask.mdx | 278 +++++++++++ src/lerobot/datasets/lerobot_dataset.py | 9 + src/lerobot/datasets/utils.py | 9 + src/lerobot/processor/converters.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 46 ++ src/lerobot/utils/constants.py | 3 + tests/datasets/test_subtask_dataset.py | 190 ++++++++ tests/processor/test_tokenizer_processor.py | 465 ++++++++++++++++++- 9 files changed, 1003 insertions(+), 2 deletions(-) create mode 100644 docs/source/dataset_subtask.mdx create mode 100644 tests/datasets/test_subtask_dataset.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 98417f134..d61aac9c1 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets - local: using_dataset_tools title: Using the Dataset Tools + - local: dataset_subtask + title: Using Subtasks in the Dataset title: "Datasets" - sections: - local: act diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx new file mode 100644 index 000000000..beb5d80bd --- /dev/null +++ b/docs/source/dataset_subtask.mdx @@ -0,0 +1,278 @@ +# Using Subtasks in LeRobot Datasets + +Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. 
Subtasks are particularly useful for: + +- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time +- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models) +- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps + +LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks. + +## What are Subtasks? + +While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps: + +1. "Approach the apple" +2. "Grasp the apple" +3. "Lift the apple" +4. "Move to basket" +5. "Release the apple" + +Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages. + +An overview of subtask annotation showing how frames are labeled with intermediate subtask stages + +

+ Figure: Overview of subtask annotation. +

+ +**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022. + +## Dataset Structure + +Subtask information is stored in the dataset metadata: + +``` +my-dataset/ +├── data/ +│ └── ... +├── meta/ +│ ├── info.json +│ ├── stats.json +│ ├── tasks.parquet +│ ├── subtasks.parquet # Subtask index → subtask string mapping +│ └── episodes/ +│ └── ... +└── videos/ + └── ... +``` + +### Subtasks Parquet File + +The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions: + +| subtask_index | subtask (index column) | +| ------------- | ---------------------- | +| 0 | "Approach the apple" | +| 1 | "Grasp the apple" | +| 2 | "Lift the apple" | +| ... | ... | + +### Frame-Level Annotations + +Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file: + +```python +# Example frame data in the parquet file +{ + "index": 42, + "timestamp": 1.4, + "episode_index": 0, + "task_index": 0, + "subtask_index": 2, # References "Lift the apple" + "observation.state": [...], + "action": [...], +} +``` + +## Annotating Datasets with Subtasks + +We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks: + +**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)** + +After completing your annotation: + +1. Click "Push to Hub" to upload your annotated dataset +2. 
You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate) + +## Loading Datasets with Subtasks + +When you load a dataset with subtask annotations, the subtask information is automatically available: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load a dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Access a sample +sample = dataset[100] + +# The sample includes both task and subtask information +print(sample["task"]) # "Collect the fruit" +print(sample["subtask"]) # "Grasp the apple" +print(sample["task_index"]) # tensor(0) +print(sample["subtask_index"]) # tensor(2) +``` + +### Checking for Subtask Support + +You can check if a dataset has subtask annotations: + +```python +# Check if subtasks are available +has_subtasks = ( + "subtask_index" in dataset.features + and dataset.meta.subtasks is not None +) + +if has_subtasks: + print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks") + print("Subtasks:", list(dataset.meta.subtasks.index)) +``` + +## Using Subtasks for Training + +### With the Tokenizer Processor + +The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models: + +```python +from lerobot.processor.tokenizer_processor import TokenizerProcessor +from lerobot.processor.pipeline import ProcessorPipeline + +# Create a tokenizer processor +tokenizer_processor = TokenizerProcessor( + tokenizer_name_or_path="google/paligemma-3b-pt-224", + padding="max_length", + max_length=64, +) + +# The processor will automatically tokenize subtasks if present in the batch +# and add them to the observation under: +# - "observation.subtask.tokens" +# - "observation.subtask.attention_mask" +``` + +When subtasks are available in the batch, the tokenizer processor adds: + +- `observation.subtask.tokens`: Tokenized subtask 
text +- `observation.subtask.attention_mask`: Attention mask for the subtask tokens + +### DataLoader with Subtasks + +```python +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=16, + shuffle=True, +) + +for batch in dataloader: + # Access subtask information in the batch + subtasks = batch["subtask"] # List of subtask strings + subtask_indices = batch["subtask_index"] # Tensor of subtask indices + + # Use for training hierarchical policies or reward models + print(f"Batch subtasks: {set(subtasks)}") +``` + +## Example Datasets with Subtask Annotations + +Try loading a dataset with subtask annotations: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Example dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Explore the subtasks +print("Available subtasks:") +for subtask_name in dataset.meta.subtasks.index: + print(f" - {subtask_name}") + +# Get subtask distribution +subtask_counts = {} +for i in range(len(dataset)): + sample = dataset[i] + subtask = sample["subtask"] + subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1 + +print("\nSubtask distribution:") +for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]): + print(f" {subtask}: {count} frames") +``` + +## Use Cases + +### 1. Hierarchical Policy Training + +Train policies that predict both actions and current subtask: + +```python +class HierarchicalPolicy(nn.Module): + def __init__(self, num_subtasks): + super().__init__() + self.action_head = nn.Linear(hidden_dim, action_dim) + self.subtask_head = nn.Linear(hidden_dim, num_subtasks) + + def forward(self, observations): + features = self.encoder(observations) + actions = self.action_head(features) + subtask_logits = self.subtask_head(features) + return actions, subtask_logits +``` + +### 2. 
Stage-Aware Reward Modeling (SARM) + +Build reward models that understand task progression: + +```python +# SARM predicts: +# - Stage: Which subtask is being executed (discrete) +# - Progress: How far along the subtask (continuous 0-1) + +class SARMRewardModel(nn.Module): + def forward(self, observations): + features = self.encoder(observations) + stage_logits = self.stage_classifier(features) + progress = self.progress_regressor(features) + return stage_logits, progress +``` + +### 3. Progress Visualization + +Monitor robot execution by tracking subtask progression: + +```python +def visualize_execution(model, observations): + for t, obs in enumerate(observations): + action, subtask_logits = model(obs) + predicted_subtask = subtask_names[subtask_logits.argmax()] + print(f"t={t}: Executing '{predicted_subtask}'") +``` + +## API Reference + +### LeRobotDataset Properties + +| Property | Type | Description | +| --------------------------- | ---------------------- | ------------------------------------------ | +| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices | +| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present | + +### Sample Keys + +When subtasks are available, each sample includes: + +| Key | Type | Description | +| --------------- | -------------- | ------------------------------------ | +| `subtask_index` | `torch.Tensor` | Integer index of the current subtask | +| `subtask` | `str` | Natural language subtask description | + +## Related Resources + +- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation +- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool +- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6798e7fd7..36bffa190 100644 --- 
a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -57,6 +57,7 @@ from lerobot.datasets.utils import ( load_info, load_nested_dataset, load_stats, + load_subtasks, load_tasks, update_chunk_file_indices, validate_episode_buffer, @@ -162,6 +163,7 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -518,6 +520,7 @@ class LeRobotDatasetMetadata: _validate_feature_names(features) obj.tasks = None + obj.subtasks = None obj.episodes = None obj.stats = None obj.info = create_empty_dataset_info( @@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self.features and self.meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name + return item def __repr__(self): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed678af6e..321ecedd5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -60,6 +60,7 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" @@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame: return tasks +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from 
subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + def write_episodes(episodes: Dataset, local_dir: Path) -> None: """Write episode metadata to a parquet file in the LeRobot v3.0 format. This function writes episode-level metadata to a single parquet file. diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 4f9485fee..18c7b0220 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} + return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key} def create_transition( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5cd1bebb0..df559555a 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -34,6 +34,8 @@ from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, OBS_LANGUAGE_TOKENS, ) from lerobot.utils.import_utils import _transformers_available @@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_subtask(self, transition: EnvTransition) 
-> list[str] | None: + """ + Extracts the subtask from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get("subtask") + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. @@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + tokenized_subtask = self._tokenize_text(subtask) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_subtask = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask.items() + } + + # Add tokenized subtask to the observation + new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to( + dtype=torch.bool + ) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43a61b4f7..ecd54844c 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 
+26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" +OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "." diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py new file mode 100644 index 000000000..f80a6c72d --- /dev/null +++ b/tests/datasets/test_subtask_dataset.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for subtask functionality in LeRobotDataset. 
+ +These tests verify that: +- Subtask information is correctly loaded from datasets that have subtask data +- The __getitem__ method correctly adds subtask strings to returned items +- Subtask handling gracefully handles missing data +""" + +import pandas as pd +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +class TestSubtaskDataset: + """Tests for subtask handling in LeRobotDataset.""" + + @pytest.fixture + def subtask_dataset(self): + """Load the test subtask dataset from the hub.""" + # Use lerobot/pusht-subtask dataset with episode 1 + return LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + + def test_subtask_dataset_loads(self, subtask_dataset): + """Test that the subtask dataset loads successfully.""" + assert subtask_dataset is not None + assert len(subtask_dataset) > 0 + + def test_subtask_metadata_loaded(self, subtask_dataset): + """Test that subtask metadata is loaded when present in dataset.""" + # The dataset should have subtasks metadata loaded + assert subtask_dataset.meta.subtasks is not None + assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame) + + def test_subtask_index_in_features(self, subtask_dataset): + """Test that subtask_index is a feature when dataset has subtasks.""" + assert "subtask_index" in subtask_dataset.features + + def test_getitem_returns_subtask_string(self, subtask_dataset): + """Test that __getitem__ correctly adds subtask string to returned item.""" + item = subtask_dataset[0] + + # Subtask should be present in the returned item + assert "subtask" in item + assert isinstance(item["subtask"], str) + assert len(item["subtask"]) > 0 # Should not be empty + + def test_getitem_has_subtask_index(self, subtask_dataset): + """Test that __getitem__ includes subtask_index.""" + item = subtask_dataset[0] + + assert "subtask_index" in item + assert isinstance(item["subtask_index"], torch.Tensor) + + def 
test_subtask_index_maps_to_valid_subtask(self, subtask_dataset): + """Test that subtask_index correctly maps to a subtask in metadata.""" + item = subtask_dataset[0] + + subtask_idx = item["subtask_index"].item() + subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name + + assert item["subtask"] == subtask_from_metadata + + def test_all_items_have_subtask(self, subtask_dataset): + """Test that all items in the dataset have subtask information.""" + for i in range(min(len(subtask_dataset), 5)): # Check first 5 items + item = subtask_dataset[i] + assert "subtask" in item + assert isinstance(item["subtask"], str) + + def test_task_and_subtask_coexist(self, subtask_dataset): + """Test that both task and subtask are present in returned items.""" + item = subtask_dataset[0] + + # Both task and subtask should be present + assert "task" in item + assert "subtask" in item + assert isinstance(item["task"], str) + assert isinstance(item["subtask"], str) + + +class TestSubtaskDatasetMissing: + """Tests for graceful handling when subtask data is missing.""" + + @pytest.fixture + def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory): + """Create a dataset without subtask information.""" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features) + + # Add some frames and save + for _ in range(5): + dataset.add_frame({"state": torch.randn(2), "task": "Test task"}) + dataset.save_episode() + dataset.finalize() + + # Reload the dataset + return LeRobotDataset(dataset.repo_id, root=dataset.root) + + def test_no_subtask_in_features(self, dataset_without_subtasks): + """Test that subtask_index is not in features when not provided.""" + assert "subtask_index" not in dataset_without_subtasks.features + + def test_getitem_without_subtask(self, dataset_without_subtasks): + """Test that __getitem__ works when subtask is not present.""" + item 
= dataset_without_subtasks[0] + + # Item should still be retrievable + assert item is not None + assert "state" in item + assert "task" in item + + # Subtask should NOT be present + assert "subtask" not in item + + def test_subtasks_metadata_is_none(self, dataset_without_subtasks): + """Test that subtasks metadata is None when not present.""" + assert dataset_without_subtasks.meta.subtasks is None + + +class TestSubtaskEdgeCases: + """Edge case tests for subtask handling.""" + + def test_subtask_with_multiple_episodes(self): + """Test subtask handling with multiple episodes if available.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + # Check first and last items have valid subtasks + first_item = dataset[0] + last_item = dataset[len(dataset) - 1] + + assert "subtask" in first_item + assert "subtask" in last_item + assert isinstance(first_item["subtask"], str) + assert isinstance(last_item["subtask"], str) + + def test_subtask_index_consistency(self): + """Test that same subtask_index returns same subtask string.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + if len(dataset) < 2: + pytest.skip("Dataset too small for this test") + + # Collect subtask_index to subtask mappings + subtask_map = {} + for i in range(min(len(dataset), 10)): + item = dataset[i] + idx = item["subtask_index"].item() + subtask = item["subtask"] + + if idx in subtask_map: + # Same index should always return same subtask + assert subtask_map[idx] == subtask, ( + f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'" + ) + else: + subtask_map[idx] = subtask diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 
d6f87f567..64cc8aac8 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -27,7 +27,14 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGE, + OBS_LANGUAGE, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, + OBS_STATE, +) from tests.utils import require_package @@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario(): # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) assert tokens.shape == (10,) # MockTokenizer behavior for single string in list assert attention_mask.shape == (10,) + + +# ============================================================================= +# Tests for get_subtask method +# ============================================================================= + + +@require_package("transformers") +def test_get_subtask_missing_key(): + """Test get_subtask returns None when subtask key is missing from complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No "subtask" key + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_value(): + """Test get_subtask returns None when subtask value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + 
transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_complementary_data(): + """Test get_subtask returns None when complementary_data is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data=None, # No complementary data + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_string(): + """Test get_subtask returns list with single string when subtask is a string.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the cube"}, + ) + + result = processor.get_subtask(transition) + assert result == ["pick up the cube"] + assert isinstance(result, list) + assert len(result) == 1 + + +@require_package("transformers") +def test_get_subtask_list_of_strings(): + """Test get_subtask returns the list when subtask is already a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + subtask_list = ["pick up", "move to target", "place down"] + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": subtask_list}, + ) + + result = 
processor.get_subtask(transition) + assert result == subtask_list + assert isinstance(result, list) + assert len(result) == 3 + + +@require_package("transformers") +def test_get_subtask_unsupported_type_integer(): + """Test get_subtask returns None when subtask is an unsupported type (integer).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_mixed_list(): + """Test get_subtask returns None when subtask is a list with mixed types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_dict(): + """Test get_subtask returns None when subtask is a dictionary.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": {"key": "value"}}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_empty_string(): + """Test get_subtask with empty string returns list with empty string.""" + mock_tokenizer = 
MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ""}, + ) + + result = processor.get_subtask(transition) + assert result == [""] + + +@require_package("transformers") +def test_get_subtask_empty_list(): + """Test get_subtask with empty list returns empty list.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": []}, + ) + + result = processor.get_subtask(transition) + assert result == [] + + +# ============================================================================= +# Tests for subtask tokenization in observation method +# ============================================================================= + + +@require_package("transformers") +def test_subtask_tokenization_when_present(): + """Test that subtask is tokenized and added to observation when present.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = 
observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert isinstance(subtask_tokens, torch.Tensor) + assert isinstance(subtask_attention_mask, torch.Tensor) + assert subtask_tokens.shape == (8,) + assert subtask_attention_mask.shape == (8,) + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_none(): + """Test that subtask tokens are NOT added to observation when subtask is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No subtask + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_subtask_value_is_none(): + """Test that subtask tokens are NOT added when subtask value is explicitly None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + 
@require_package("transformers")
def test_subtask_tokenization_list_of_strings():
    """A list of subtask strings is tokenized into batched tensors."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)

    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]

    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs

    # Two subtask strings produce a (batch, seq_len) layout.
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (2, 8)  # batch_size=2, seq_len=8
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].shape == (2, 8)


@require_package("transformers")
def test_subtask_tokenization_device_cpu():
    """Subtask tensors stay on CPU when the transition tensors live on CPU."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)

    sample = create_transition(
        observation={OBS_STATE: torch.randn(10)},  # CPU tensor
        action=torch.randn(5),  # CPU tensor
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]

    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].device.type == "cpu"
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].device.type == "cpu"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@require_package("transformers")
def test_subtask_tokenization_device_cuda():
    """Subtask tensors follow the rest of the transition onto the CUDA device."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)

    sample = create_transition(
        observation={OBS_STATE: torch.randn(10).cuda()},  # CUDA tensor
        action=torch.randn(5).cuda(),  # CUDA tensor
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]

    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].device.type == "cuda"
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].device.type == "cuda"


@require_package("transformers")
def test_subtask_tokenization_preserves_other_observation_data():
    """Tokenizing a subtask must not disturb pre-existing observation entries."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)

    original_state = torch.tensor([1.0, 2.0, 3.0])
    sample = create_transition(
        observation={"state": original_state.clone()},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]

    # The state tensor passed in comes back untouched.
    assert torch.equal(obs["state"], original_state)

    # Both the task and the subtask tokenizations are present.
    assert f"{OBS_LANGUAGE}.tokens" in obs
    assert f"{OBS_LANGUAGE}.attention_mask" in obs
    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs


@require_package("transformers")
def test_subtask_attention_mask_dtype():
    """The subtask attention mask is emitted as a boolean tensor."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)

    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "pick up cube"},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]
    assert obs[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK].dtype == torch.bool


@require_package("transformers")
def test_subtask_tokenization_deterministic():
    """Running the step twice on the same transition yields identical tensors."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=10)

    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "consistent subtask"},
    )

    first = step(sample)[TransitionKey.OBSERVATION]
    second = step(sample)[TransitionKey.OBSERVATION]

    # Same input, same output — tokens and masks match exactly.
    assert torch.equal(first[OBS_LANGUAGE_SUBTASK_TOKENS], second[OBS_LANGUAGE_SUBTASK_TOKENS])
    assert torch.equal(
        first[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK], second[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
    )


@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
    """Subtask tokenization works end to end inside a DataProcessorPipeline."""
    mock_tokenizer = MockTokenizer(vocab_size=100)
    mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

    pipeline = DataProcessorPipeline(
        [TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)],
        to_transition=identity_transition,
        to_output=identity_transition,
    )

    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": "subtask instruction"},
    )

    result = pipeline(sample)

    assert TransitionKey.OBSERVATION in result
    obs = result[TransitionKey.OBSERVATION]

    # Main-task tokenization still happens.
    assert f"{OBS_LANGUAGE}.tokens" in obs
    assert f"{OBS_LANGUAGE}.attention_mask" in obs

    # Subtask tokenization happens alongside it.
    assert OBS_LANGUAGE_SUBTASK_TOKENS in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in obs

    # Both respect the configured max_length.
    assert obs[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
    assert obs[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)


@require_package("transformers")
def test_subtask_not_added_for_unsupported_types():
    """A non-string, non-list subtask (e.g. an int) is ignored entirely."""
    step = TokenizerProcessorStep(tokenizer=MockTokenizer(vocab_size=100), max_length=8)

    # Test with integer subtask
    sample = create_transition(
        observation={"state": torch.tensor([1.0, 2.0])},
        action=torch.tensor([0.1, 0.2]),
        complementary_data={"task": "main task", "subtask": 123},
    )

    obs = step(sample)[TransitionKey.OBSERVATION]

    # Unsupported subtask types produce no subtask entries...
    assert OBS_LANGUAGE_SUBTASK_TOKENS not in obs
    assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in obs

    # ...while the main task is still tokenized.
    assert f"{OBS_LANGUAGE}.tokens" in obs