feat(dataset): add subtask support (#2860)

* add subtask

* remove folder

* add docs

* update doc

* add testing

* update test

* update constant naming + doc

* more docs
This commit is contained in:
Jade Choghari
2026-01-30 10:29:37 -08:00
committed by GitHub
parent 5c6182176f
commit b18cef2e26
9 changed files with 1003 additions and 2 deletions
+2
View File
@@ -27,6 +27,8 @@
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
- local: dataset_subtask
title: Using Subtasks in the Dataset
title: "Datasets"
- sections:
- local: act
+278
View File
@@ -0,0 +1,278 @@
# Using Subtasks in LeRobot Datasets
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
## What are Subtasks?
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
1. "Approach the apple"
2. "Grasp the apple"
3. "Lift the apple"
4. "Move to basket"
5. "Release the apple"
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
width="80%"
/>
<p>
<em>Figure: Overview of subtask annotation.</em>
</p>
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
## Dataset Structure
Subtask information is stored in the dataset metadata:
```
my-dataset/
├── data/
│ └── ...
├── meta/
│ ├── info.json
│ ├── stats.json
│ ├── tasks.parquet
│ ├── subtasks.parquet # Subtask index → subtask string mapping
│ └── episodes/
│ └── ...
└── videos/
└── ...
```
### Subtasks Parquet File
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
| subtask_index | subtask (index column) |
| ------------- | ---------------------- |
| 0 | "Approach the apple" |
| 1 | "Grasp the apple" |
| 2 | "Lift the apple" |
| ... | ... |
### Frame-Level Annotations
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
```python
# Example frame data in the parquet file
{
"index": 42,
"timestamp": 1.4,
"episode_index": 0,
"task_index": 0,
"subtask_index": 2, # References "Lift the apple"
"observation.state": [...],
"action": [...],
}
```
## Annotating Datasets with Subtasks
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
After completing your annotation:
1. Click "Push to Hub" to upload your annotated dataset
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
## Loading Datasets with Subtasks
When you load a dataset with subtask annotations, the subtask information is automatically available:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load a dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Access a sample
sample = dataset[100]
# The sample includes both task and subtask information
print(sample["task"]) # "Collect the fruit"
print(sample["subtask"]) # "Grasp the apple"
print(sample["task_index"]) # tensor(0)
print(sample["subtask_index"]) # tensor(2)
```
### Checking for Subtask Support
You can check if a dataset has subtask annotations:
```python
# Check if subtasks are available
has_subtasks = (
"subtask_index" in dataset.features
and dataset.meta.subtasks is not None
)
if has_subtasks:
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
print("Subtasks:", list(dataset.meta.subtasks.index))
```
## Using Subtasks for Training
### With the Tokenizer Processor
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
```python
from lerobot.processor.tokenizer_processor import TokenizerProcessor
from lerobot.processor.pipeline import ProcessorPipeline
# Create a tokenizer processor
tokenizer_processor = TokenizerProcessor(
tokenizer_name_or_path="google/paligemma-3b-pt-224",
padding="max_length",
max_length=64,
)
# The processor will automatically tokenize subtasks if present in the batch
# and add them to the observation under:
# - "observation.subtask.tokens"
# - "observation.subtask.attention_mask"
```
When subtasks are available in the batch, the tokenizer processor adds:
- `observation.subtask.tokens`: Tokenized subtask text
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
### DataLoader with Subtasks
```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=16,
shuffle=True,
)
for batch in dataloader:
# Access subtask information in the batch
subtasks = batch["subtask"] # List of subtask strings
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
# Use for training hierarchical policies or reward models
print(f"Batch subtasks: {set(subtasks)}")
```
## Example Datasets with Subtask Annotations
Try loading a dataset with subtask annotations:
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Example dataset with subtask annotations
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
# Explore the subtasks
print("Available subtasks:")
for subtask_name in dataset.meta.subtasks.index:
print(f" - {subtask_name}")
# Get subtask distribution
subtask_counts = {}
for i in range(len(dataset)):
sample = dataset[i]
subtask = sample["subtask"]
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
print("\nSubtask distribution:")
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
print(f" {subtask}: {count} frames")
```
## Use Cases
### 1. Hierarchical Policy Training
Train policies that predict both actions and current subtask:
```python
class HierarchicalPolicy(nn.Module):
def __init__(self, num_subtasks):
super().__init__()
self.action_head = nn.Linear(hidden_dim, action_dim)
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
def forward(self, observations):
features = self.encoder(observations)
actions = self.action_head(features)
subtask_logits = self.subtask_head(features)
return actions, subtask_logits
```
### 2. Stage-Aware Reward Modeling (SARM)
Build reward models that understand task progression:
```python
# SARM predicts:
# - Stage: Which subtask is being executed (discrete)
# - Progress: How far along the subtask (continuous 0-1)
class SARMRewardModel(nn.Module):
def forward(self, observations):
features = self.encoder(observations)
stage_logits = self.stage_classifier(features)
progress = self.progress_regressor(features)
return stage_logits, progress
```
### 3. Progress Visualization
Monitor robot execution by tracking subtask progression:
```python
def visualize_execution(model, observations):
for t, obs in enumerate(observations):
action, subtask_logits = model(obs)
predicted_subtask = subtask_names[subtask_logits.argmax()]
print(f"t={t}: Executing '{predicted_subtask}'")
```
## API Reference
### LeRobotDataset Properties
| Property | Type | Description |
| --------------------------- | ---------------------- | ------------------------------------------ |
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
### Sample Keys
When subtasks are available, each sample includes:
| Key | Type | Description |
| --------------- | -------------- | ------------------------------------ |
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
| `subtask` | `str` | Natural language subtask description |
## Related Resources
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
+9
View File
@@ -57,6 +57,7 @@ from lerobot.datasets.utils import (
load_info,
load_nested_dataset,
load_stats,
load_subtasks,
load_tasks,
update_chunk_file_indices,
validate_episode_buffer,
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
self.info = load_info(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks = load_tasks(self.root)
self.subtasks = load_subtasks(self.root)
self.episodes = load_episodes(self.root)
self.stats = load_stats(self.root)
@@ -518,6 +520,7 @@ class LeRobotDatasetMetadata:
_validate_feature_names(features)
obj.tasks = None
obj.subtasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(
@@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks.iloc[task_idx].name
# add subtask information if available
if "subtask_index" in self.features and self.meta.subtasks is not None:
subtask_idx = item["subtask_index"].item()
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
return item
def __repr__(self):
+9
View File
@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
return tasks
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
"""Load subtasks from subtasks.parquet if it exists."""
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
if subtasks_path.exists():
return pd.read_parquet(subtasks_path)
return None
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
This function writes episode-level metadata to a single parquet file.
+2 -1
View File
@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
"""
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
task_key = {"task": batch["task"]} if "task" in batch else {}
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
index_key = {"index": batch["index"]} if "index" in batch else {}
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
def create_transition(
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_LANGUAGE_TOKENS,
)
from lerobot.utils.import_utils import _transformers_available
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
return None
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
"""
Extracts the subtask from the transition's complementary data.
Args:
transition: The environment transition.
Returns:
A list of subtask strings, or None if the subtask key is not found or the value is None.
"""
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
if complementary_data is None:
return None
subtask = complementary_data.get("subtask")
if subtask is None:
return None
# Standardize to a list of strings for the tokenizer
if isinstance(subtask, str):
return [subtask]
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
return subtask
return None
def observation(self, observation: RobotObservation) -> RobotObservation:
"""
Tokenizes the task description and adds it to the observation dictionary.
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
# Tokenize subtask if available
subtask = self.get_subtask(self.transition)
if subtask is not None:
tokenized_subtask = self._tokenize_text(subtask)
# Move new tokenized tensors to the detected device
if target_device is not None:
tokenized_subtask = {
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
for k, v in tokenized_subtask.items()
}
# Add tokenized subtask to the observation
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
dtype=torch.bool
)
return new_observation
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
+3
View File
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
OBS_LANGUAGE = OBS_STR + ".language"
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
ACTION = "action"
ACTION_PREFIX = ACTION + "."
+190
View File
@@ -0,0 +1,190 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for subtask functionality in LeRobotDataset.
These tests verify that:
- Subtask information is correctly loaded from datasets that have subtask data
- The __getitem__ method correctly adds subtask strings to returned items
- Subtask handling gracefully handles missing data
"""
import pandas as pd
import pytest
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
class TestSubtaskDataset:
"""Tests for subtask handling in LeRobotDataset."""
@pytest.fixture
def subtask_dataset(self):
"""Load the test subtask dataset from the hub."""
# Use lerobot/pusht-subtask dataset with episode 1
return LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
def test_subtask_dataset_loads(self, subtask_dataset):
"""Test that the subtask dataset loads successfully."""
assert subtask_dataset is not None
assert len(subtask_dataset) > 0
def test_subtask_metadata_loaded(self, subtask_dataset):
"""Test that subtask metadata is loaded when present in dataset."""
# The dataset should have subtasks metadata loaded
assert subtask_dataset.meta.subtasks is not None
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
def test_subtask_index_in_features(self, subtask_dataset):
"""Test that subtask_index is a feature when dataset has subtasks."""
assert "subtask_index" in subtask_dataset.features
def test_getitem_returns_subtask_string(self, subtask_dataset):
"""Test that __getitem__ correctly adds subtask string to returned item."""
item = subtask_dataset[0]
# Subtask should be present in the returned item
assert "subtask" in item
assert isinstance(item["subtask"], str)
assert len(item["subtask"]) > 0 # Should not be empty
def test_getitem_has_subtask_index(self, subtask_dataset):
"""Test that __getitem__ includes subtask_index."""
item = subtask_dataset[0]
assert "subtask_index" in item
assert isinstance(item["subtask_index"], torch.Tensor)
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
"""Test that subtask_index correctly maps to a subtask in metadata."""
item = subtask_dataset[0]
subtask_idx = item["subtask_index"].item()
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
assert item["subtask"] == subtask_from_metadata
def test_all_items_have_subtask(self, subtask_dataset):
"""Test that all items in the dataset have subtask information."""
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
item = subtask_dataset[i]
assert "subtask" in item
assert isinstance(item["subtask"], str)
def test_task_and_subtask_coexist(self, subtask_dataset):
"""Test that both task and subtask are present in returned items."""
item = subtask_dataset[0]
# Both task and subtask should be present
assert "task" in item
assert "subtask" in item
assert isinstance(item["task"], str)
assert isinstance(item["subtask"], str)
class TestSubtaskDatasetMissing:
"""Tests for graceful handling when subtask data is missing."""
@pytest.fixture
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
"""Create a dataset without subtask information."""
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
# Add some frames and save
for _ in range(5):
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
dataset.save_episode()
dataset.finalize()
# Reload the dataset
return LeRobotDataset(dataset.repo_id, root=dataset.root)
def test_no_subtask_in_features(self, dataset_without_subtasks):
"""Test that subtask_index is not in features when not provided."""
assert "subtask_index" not in dataset_without_subtasks.features
def test_getitem_without_subtask(self, dataset_without_subtasks):
"""Test that __getitem__ works when subtask is not present."""
item = dataset_without_subtasks[0]
# Item should still be retrievable
assert item is not None
assert "state" in item
assert "task" in item
# Subtask should NOT be present
assert "subtask" not in item
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
"""Test that subtasks metadata is None when not present."""
assert dataset_without_subtasks.meta.subtasks is None
class TestSubtaskEdgeCases:
"""Edge case tests for subtask handling."""
def test_subtask_with_multiple_episodes(self):
"""Test subtask handling with multiple episodes if available."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
# Check first and last items have valid subtasks
first_item = dataset[0]
last_item = dataset[len(dataset) - 1]
assert "subtask" in first_item
assert "subtask" in last_item
assert isinstance(first_item["subtask"], str)
assert isinstance(last_item["subtask"], str)
def test_subtask_index_consistency(self):
"""Test that same subtask_index returns same subtask string."""
try:
dataset = LeRobotDataset(
repo_id="lerobot/pusht-subtask",
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
)
except Exception:
pytest.skip("Could not load test-subtask dataset")
if len(dataset) < 2:
pytest.skip("Dataset too small for this test")
# Collect subtask_index to subtask mappings
subtask_map = {}
for i in range(min(len(dataset), 10)):
item = dataset[i]
idx = item["subtask_index"].item()
subtask = item["subtask"]
if idx in subtask_map:
# Same index should always return same subtask
assert subtask_map[idx] == subtask, (
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
)
else:
subtask_map[idx] = subtask
+464 -1
View File
@@ -27,7 +27,14 @@ import torch
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
from lerobot.processor.converters import create_transition, identity_transition
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
from lerobot.utils.constants import (
ACTION,
OBS_IMAGE,
OBS_LANGUAGE,
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
OBS_LANGUAGE_SUBTASK_TOKENS,
OBS_STATE,
)
from tests.utils import require_package
@@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario():
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
assert attention_mask.shape == (10,)
# =============================================================================
# Tests for get_subtask method
# =============================================================================
@require_package("transformers")
def test_get_subtask_missing_key():
"""Test get_subtask returns None when subtask key is missing from complementary_data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task"}, # No "subtask" key
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_none_value():
"""Test get_subtask returns None when subtask value is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": None},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_none_complementary_data():
"""Test get_subtask returns None when complementary_data is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data=None, # No complementary data
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_string():
"""Test get_subtask returns list with single string when subtask is a string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up the cube"},
)
result = processor.get_subtask(transition)
assert result == ["pick up the cube"]
assert isinstance(result, list)
assert len(result) == 1
@require_package("transformers")
def test_get_subtask_list_of_strings():
"""Test get_subtask returns the list when subtask is already a list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
subtask_list = ["pick up", "move to target", "place down"]
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": subtask_list},
)
result = processor.get_subtask(transition)
assert result == subtask_list
assert isinstance(result, list)
assert len(result) == 3
@require_package("transformers")
def test_get_subtask_unsupported_type_integer():
"""Test get_subtask returns None when subtask is an unsupported type (integer)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": 123},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_unsupported_type_mixed_list():
"""Test get_subtask returns None when subtask is a list with mixed types."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_unsupported_type_dict():
"""Test get_subtask returns None when subtask is a dictionary."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": {"key": "value"}},
)
result = processor.get_subtask(transition)
assert result is None
@require_package("transformers")
def test_get_subtask_empty_string():
"""Test get_subtask with empty string returns list with empty string."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ""},
)
result = processor.get_subtask(transition)
assert result == [""]
@require_package("transformers")
def test_get_subtask_empty_list():
"""Test get_subtask with empty list returns empty list."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": []},
)
result = processor.get_subtask(transition)
assert result == []
# =============================================================================
# Tests for subtask tokenization in observation method
# =============================================================================
@require_package("transformers")
def test_subtask_tokenization_when_present():
"""Test that subtask is tokenized and added to observation when present."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up the red cube"},
)
result = processor(transition)
# Check that subtask tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check token structure
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert isinstance(subtask_tokens, torch.Tensor)
assert isinstance(subtask_attention_mask, torch.Tensor)
assert subtask_tokens.shape == (8,)
assert subtask_attention_mask.shape == (8,)
assert subtask_attention_mask.dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_not_added_when_none():
"""Test that subtask tokens are NOT added to observation when subtask is None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task"}, # No subtask
)
result = processor(transition)
# Check that subtask tokens were NOT added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
# But main task tokens should still be present
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
@require_package("transformers")
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
"""Test that subtask tokens are NOT added when subtask value is explicitly None."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": None},
)
result = processor(transition)
# Check that subtask tokens were NOT added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
@require_package("transformers")
def test_subtask_tokenization_list_of_strings():
"""Test subtask tokenization with list of strings."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
)
result = processor(transition)
# Check that subtask tokens were added to observation
observation = result[TransitionKey.OBSERVATION]
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check token structure for batch
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8
assert subtask_attention_mask.shape == (2, 8)
@require_package("transformers")
def test_subtask_tokenization_device_cpu():
"""Test that subtask tokens are on CPU when other tensors are on CPU."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CPU tensors
observation = {OBS_STATE: torch.randn(10)} # CPU tensor
action = torch.randn(5) # CPU tensor
transition = create_transition(
observation=observation,
action=action,
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
# Check that subtask tokens are on CPU
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.device.type == "cpu"
assert subtask_attention_mask.device.type == "cpu"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@require_package("transformers")
def test_subtask_tokenization_device_cuda():
"""Test that subtask tokens are moved to CUDA when other tensors are on CUDA."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
# Create transition with CUDA tensors
observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
action = torch.randn(5).cuda() # CUDA tensor
transition = create_transition(
observation=observation,
action=action,
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
# Check that subtask tokens are on CUDA
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_tokens.device.type == "cuda"
assert subtask_attention_mask.device.type == "cuda"
@require_package("transformers")
def test_subtask_tokenization_preserves_other_observation_data():
"""Test that subtask tokenization preserves other observation data."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
original_state = torch.tensor([1.0, 2.0, 3.0])
transition = create_transition(
observation={"state": original_state.clone()},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
# Check that original observation data is preserved
assert torch.equal(observation["state"], original_state)
# Check that both task and subtask tokens are present
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
@require_package("transformers")
def test_subtask_attention_mask_dtype():
"""Test that subtask attention mask has correct dtype (bool)."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "pick up cube"},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
assert subtask_attention_mask.dtype == torch.bool
@require_package("transformers")
def test_subtask_tokenization_deterministic():
"""Test that subtask tokenization is deterministic for the same input."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "consistent subtask"},
)
result1 = processor(transition)
result2 = processor(transition)
subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
# Results should be identical
assert torch.equal(subtask_tokens1, subtask_tokens2)
assert torch.equal(subtask_mask1, subtask_mask2)
@require_package("transformers")
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
"""Test subtask tokenization works correctly with DataProcessorPipeline."""
mock_tokenizer = MockTokenizer(vocab_size=100)
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = DataProcessorPipeline(
[tokenizer_processor], to_transition=identity_transition, to_output=identity_transition
)
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": "subtask instruction"},
)
result = robot_processor(transition)
# Check that observation exists and both tokenizations were applied
assert TransitionKey.OBSERVATION in result
observation = result[TransitionKey.OBSERVATION]
# Check task tokens
assert f"{OBS_LANGUAGE}.tokens" in observation
assert f"{OBS_LANGUAGE}.attention_mask" in observation
# Check subtask tokens
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
# Check shapes
assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
@require_package("transformers")
def test_subtask_not_added_for_unsupported_types():
"""Test that subtask tokens are not added when subtask has unsupported type."""
mock_tokenizer = MockTokenizer(vocab_size=100)
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
# Test with integer subtask
transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
action=torch.tensor([0.1, 0.2]),
complementary_data={"task": "main task", "subtask": 123},
)
result = processor(transition)
observation = result[TransitionKey.OBSERVATION]
# Subtask tokens should NOT be added for unsupported types
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
# But main task tokens should still be present
assert f"{OBS_LANGUAGE}.tokens" in observation