mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
feat(dataset): add subtask support (#2860)
* add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs
This commit is contained in:
@@ -27,6 +27,8 @@
|
||||
title: Porting Large Datasets
|
||||
- local: using_dataset_tools
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
# Using Subtasks in LeRobot Datasets
|
||||
|
||||
Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. Subtasks are particularly useful for:
|
||||
|
||||
- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time
|
||||
- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models)
|
||||
- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps
|
||||
|
||||
LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks.
|
||||
|
||||
## What are Subtasks?
|
||||
|
||||
While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps:
|
||||
|
||||
1. "Approach the apple"
|
||||
2. "Grasp the apple"
|
||||
3. "Lift the apple"
|
||||
4. "Move to basket"
|
||||
5. "Release the apple"
|
||||
|
||||
Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages.
|
||||
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/subtask-asset.png"
|
||||
alt="An overview of subtask annotation showing how frames are labeled with intermediate subtask stages"
|
||||
width="80%"
|
||||
/>
|
||||
|
||||
<p>
|
||||
<em>Figure: Overview of subtask annotation.</em>
|
||||
</p>
|
||||
|
||||
**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022.
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
Subtask information is stored in the dataset metadata:
|
||||
|
||||
```
|
||||
my-dataset/
|
||||
├── data/
|
||||
│ └── ...
|
||||
├── meta/
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ ├── tasks.parquet
|
||||
│ ├── subtasks.parquet # Subtask index → subtask string mapping
|
||||
│ └── episodes/
|
||||
│ └── ...
|
||||
└── videos/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Subtasks Parquet File
|
||||
|
||||
The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions:
|
||||
|
||||
| subtask_index | subtask (index column) |
|
||||
| ------------- | ---------------------- |
|
||||
| 0 | "Approach the apple" |
|
||||
| 1 | "Grasp the apple" |
|
||||
| 2 | "Lift the apple" |
|
||||
| ... | ... |
|
||||
|
||||
### Frame-Level Annotations
|
||||
|
||||
Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file:
|
||||
|
||||
```python
|
||||
# Example frame data in the parquet file
|
||||
{
|
||||
"index": 42,
|
||||
"timestamp": 1.4,
|
||||
"episode_index": 0,
|
||||
"task_index": 0,
|
||||
"subtask_index": 2, # References "Lift the apple"
|
||||
"observation.state": [...],
|
||||
"action": [...],
|
||||
}
|
||||
```
|
||||
|
||||
## Annotating Datasets with Subtasks
|
||||
|
||||
We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks:
|
||||
|
||||
**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)**
|
||||
|
||||
After completing your annotation:
|
||||
|
||||
1. Click "Push to Hub" to upload your annotated dataset
|
||||
2. You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate)
|
||||
|
||||
## Loading Datasets with Subtasks
|
||||
|
||||
When you load a dataset with subtask annotations, the subtask information is automatically available:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load a dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Access a sample
|
||||
sample = dataset[100]
|
||||
|
||||
# The sample includes both task and subtask information
|
||||
print(sample["task"]) # "Collect the fruit"
|
||||
print(sample["subtask"]) # "Grasp the apple"
|
||||
print(sample["task_index"]) # tensor(0)
|
||||
print(sample["subtask_index"]) # tensor(2)
|
||||
```
|
||||
|
||||
### Checking for Subtask Support
|
||||
|
||||
You can check if a dataset has subtask annotations:
|
||||
|
||||
```python
|
||||
# Check if subtasks are available
|
||||
has_subtasks = (
|
||||
"subtask_index" in dataset.features
|
||||
and dataset.meta.subtasks is not None
|
||||
)
|
||||
|
||||
if has_subtasks:
|
||||
print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks")
|
||||
print("Subtasks:", list(dataset.meta.subtasks.index))
|
||||
```
|
||||
|
||||
## Using Subtasks for Training
|
||||
|
||||
### With the Tokenizer Processor
|
||||
|
||||
The `TokenizerProcessor` automatically handles subtask tokenization for Vision-Language Action (VLA) models:
|
||||
|
||||
```python
|
||||
from lerobot.processor.tokenizer_processor import TokenizerProcessor
|
||||
from lerobot.processor.pipeline import ProcessorPipeline
|
||||
|
||||
# Create a tokenizer processor
|
||||
tokenizer_processor = TokenizerProcessor(
|
||||
tokenizer_name_or_path="google/paligemma-3b-pt-224",
|
||||
padding="max_length",
|
||||
max_length=64,
|
||||
)
|
||||
|
||||
# The processor will automatically tokenize subtasks if present in the batch
|
||||
# and add them to the observation under:
|
||||
# - "observation.subtask.tokens"
|
||||
# - "observation.subtask.attention_mask"
|
||||
```
|
||||
|
||||
When subtasks are available in the batch, the tokenizer processor adds:
|
||||
|
||||
- `observation.subtask.tokens`: Tokenized subtask text
|
||||
- `observation.subtask.attention_mask`: Attention mask for the subtask tokens
|
||||
|
||||
### DataLoader with Subtasks
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=16,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
# Access subtask information in the batch
|
||||
subtasks = batch["subtask"] # List of subtask strings
|
||||
subtask_indices = batch["subtask_index"] # Tensor of subtask indices
|
||||
|
||||
# Use for training hierarchical policies or reward models
|
||||
print(f"Batch subtasks: {set(subtasks)}")
|
||||
```
|
||||
|
||||
## Example Datasets with Subtask Annotations
|
||||
|
||||
Try loading a dataset with subtask annotations:
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Example dataset with subtask annotations
|
||||
dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated")
|
||||
|
||||
# Explore the subtasks
|
||||
print("Available subtasks:")
|
||||
for subtask_name in dataset.meta.subtasks.index:
|
||||
print(f" - {subtask_name}")
|
||||
|
||||
# Get subtask distribution
|
||||
subtask_counts = {}
|
||||
for i in range(len(dataset)):
|
||||
sample = dataset[i]
|
||||
subtask = sample["subtask"]
|
||||
subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1
|
||||
|
||||
print("\nSubtask distribution:")
|
||||
for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]):
|
||||
print(f" {subtask}: {count} frames")
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Hierarchical Policy Training
|
||||
|
||||
Train policies that predict both actions and current subtask:
|
||||
|
||||
```python
|
||||
class HierarchicalPolicy(nn.Module):
|
||||
def __init__(self, num_subtasks):
|
||||
super().__init__()
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.subtask_head = nn.Linear(hidden_dim, num_subtasks)
|
||||
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
actions = self.action_head(features)
|
||||
subtask_logits = self.subtask_head(features)
|
||||
return actions, subtask_logits
|
||||
```
|
||||
|
||||
### 2. Stage-Aware Reward Modeling (SARM)
|
||||
|
||||
Build reward models that understand task progression:
|
||||
|
||||
```python
|
||||
# SARM predicts:
|
||||
# - Stage: Which subtask is being executed (discrete)
|
||||
# - Progress: How far along the subtask (continuous 0-1)
|
||||
|
||||
class SARMRewardModel(nn.Module):
|
||||
def forward(self, observations):
|
||||
features = self.encoder(observations)
|
||||
stage_logits = self.stage_classifier(features)
|
||||
progress = self.progress_regressor(features)
|
||||
return stage_logits, progress
|
||||
```
|
||||
|
||||
### 3. Progress Visualization
|
||||
|
||||
Monitor robot execution by tracking subtask progression:
|
||||
|
||||
```python
|
||||
def visualize_execution(model, observations):
|
||||
for t, obs in enumerate(observations):
|
||||
action, subtask_logits = model(obs)
|
||||
predicted_subtask = subtask_names[subtask_logits.argmax()]
|
||||
print(f"t={t}: Executing '{predicted_subtask}'")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### LeRobotDataset Properties
|
||||
|
||||
| Property | Type | Description |
|
||||
| --------------------------- | ---------------------- | ------------------------------------------ |
|
||||
| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices |
|
||||
| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present |
|
||||
|
||||
### Sample Keys
|
||||
|
||||
When subtasks are available, each sample includes:
|
||||
|
||||
| Key | Type | Description |
|
||||
| --------------- | -------------- | ------------------------------------ |
|
||||
| `subtask_index` | `torch.Tensor` | Integer index of the current subtask |
|
||||
| `subtask` | `str` | Natural language subtask description |
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation
|
||||
- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool
|
||||
- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation
|
||||
@@ -57,6 +57,7 @@ from lerobot.datasets.utils import (
|
||||
load_info,
|
||||
load_nested_dataset,
|
||||
load_stats,
|
||||
load_subtasks,
|
||||
load_tasks,
|
||||
update_chunk_file_indices,
|
||||
validate_episode_buffer,
|
||||
@@ -162,6 +163,7 @@ class LeRobotDatasetMetadata:
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.subtasks = load_subtasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
|
||||
@@ -518,6 +520,7 @@ class LeRobotDatasetMetadata:
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
obj.subtasks = None
|
||||
obj.episodes = None
|
||||
obj.stats = None
|
||||
obj.info = create_empty_dataset_info(
|
||||
@@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Add task as a string
|
||||
task_idx = item["task_index"].item()
|
||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
||||
|
||||
# add subtask information if available
|
||||
if "subtask_index" in self.features and self.meta.subtasks is not None:
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -60,6 +60,7 @@ VIDEO_DIR = "videos"
|
||||
|
||||
CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}"
|
||||
DEFAULT_TASKS_PATH = "meta/tasks.parquet"
|
||||
DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet"
|
||||
DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
@@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame:
|
||||
return tasks
|
||||
|
||||
|
||||
def load_subtasks(local_dir: Path) -> pandas.DataFrame | None:
|
||||
"""Load subtasks from subtasks.parquet if it exists."""
|
||||
subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH
|
||||
if subtasks_path.exists():
|
||||
return pd.read_parquet(subtasks_path)
|
||||
return None
|
||||
|
||||
|
||||
def write_episodes(episodes: Dataset, local_dir: Path) -> None:
|
||||
"""Write episode metadata to a parquet file in the LeRobot v3.0 format.
|
||||
This function writes episode-level metadata to a single parquet file.
|
||||
|
||||
@@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -34,6 +34,8 @@ from lerobot.utils.constants import (
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
)
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
@@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
|
||||
return None
|
||||
|
||||
def get_subtask(self, transition: EnvTransition) -> list[str] | None:
|
||||
"""
|
||||
Extracts the subtask from the transition's complementary data.
|
||||
|
||||
Args:
|
||||
transition: The environment transition.
|
||||
|
||||
Returns:
|
||||
A list of subtask strings, or None if the subtask key is not found or the value is None.
|
||||
"""
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA)
|
||||
if complementary_data is None:
|
||||
return None
|
||||
|
||||
subtask = complementary_data.get("subtask")
|
||||
if subtask is None:
|
||||
return None
|
||||
|
||||
# Standardize to a list of strings for the tokenizer
|
||||
if isinstance(subtask, str):
|
||||
return [subtask]
|
||||
elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask):
|
||||
return subtask
|
||||
|
||||
return None
|
||||
|
||||
def observation(self, observation: RobotObservation) -> RobotObservation:
|
||||
"""
|
||||
Tokenizes the task description and adds it to the observation dictionary.
|
||||
@@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep):
|
||||
new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool)
|
||||
|
||||
# Tokenize subtask if available
|
||||
subtask = self.get_subtask(self.transition)
|
||||
if subtask is not None:
|
||||
tokenized_subtask = self._tokenize_text(subtask)
|
||||
|
||||
# Move new tokenized tensors to the detected device
|
||||
if target_device is not None:
|
||||
tokenized_subtask = {
|
||||
k: v.to(target_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in tokenized_subtask.items()
|
||||
}
|
||||
|
||||
# Add tokenized subtask to the observation
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"]
|
||||
new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to(
|
||||
dtype=torch.bool
|
||||
)
|
||||
|
||||
return new_observation
|
||||
|
||||
def _detect_device(self, transition: EnvTransition) -> torch.device | None:
|
||||
|
||||
@@ -26,6 +26,9 @@ OBS_IMAGES = OBS_IMAGE + "s"
|
||||
OBS_LANGUAGE = OBS_STR + ".language"
|
||||
OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens"
|
||||
OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask"
|
||||
OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask"
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens"
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask"
|
||||
|
||||
ACTION = "action"
|
||||
ACTION_PREFIX = ACTION + "."
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for subtask functionality in LeRobotDataset.
|
||||
|
||||
These tests verify that:
|
||||
- Subtask information is correctly loaded from datasets that have subtask data
|
||||
- The __getitem__ method correctly adds subtask strings to returned items
|
||||
- Subtask handling gracefully handles missing data
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class TestSubtaskDataset:
|
||||
"""Tests for subtask handling in LeRobotDataset."""
|
||||
|
||||
@pytest.fixture
|
||||
def subtask_dataset(self):
|
||||
"""Load the test subtask dataset from the hub."""
|
||||
# Use lerobot/pusht-subtask dataset with episode 1
|
||||
return LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
|
||||
def test_subtask_dataset_loads(self, subtask_dataset):
|
||||
"""Test that the subtask dataset loads successfully."""
|
||||
assert subtask_dataset is not None
|
||||
assert len(subtask_dataset) > 0
|
||||
|
||||
def test_subtask_metadata_loaded(self, subtask_dataset):
|
||||
"""Test that subtask metadata is loaded when present in dataset."""
|
||||
# The dataset should have subtasks metadata loaded
|
||||
assert subtask_dataset.meta.subtasks is not None
|
||||
assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame)
|
||||
|
||||
def test_subtask_index_in_features(self, subtask_dataset):
|
||||
"""Test that subtask_index is a feature when dataset has subtasks."""
|
||||
assert "subtask_index" in subtask_dataset.features
|
||||
|
||||
def test_getitem_returns_subtask_string(self, subtask_dataset):
|
||||
"""Test that __getitem__ correctly adds subtask string to returned item."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Subtask should be present in the returned item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
assert len(item["subtask"]) > 0 # Should not be empty
|
||||
|
||||
def test_getitem_has_subtask_index(self, subtask_dataset):
|
||||
"""Test that __getitem__ includes subtask_index."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
assert "subtask_index" in item
|
||||
assert isinstance(item["subtask_index"], torch.Tensor)
|
||||
|
||||
def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
|
||||
"""Test that subtask_index correctly maps to a subtask in metadata."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
subtask_idx = item["subtask_index"].item()
|
||||
subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name
|
||||
|
||||
assert item["subtask"] == subtask_from_metadata
|
||||
|
||||
def test_all_items_have_subtask(self, subtask_dataset):
|
||||
"""Test that all items in the dataset have subtask information."""
|
||||
for i in range(min(len(subtask_dataset), 5)): # Check first 5 items
|
||||
item = subtask_dataset[i]
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
def test_task_and_subtask_coexist(self, subtask_dataset):
|
||||
"""Test that both task and subtask are present in returned items."""
|
||||
item = subtask_dataset[0]
|
||||
|
||||
# Both task and subtask should be present
|
||||
assert "task" in item
|
||||
assert "subtask" in item
|
||||
assert isinstance(item["task"], str)
|
||||
assert isinstance(item["subtask"], str)
|
||||
|
||||
|
||||
class TestSubtaskDatasetMissing:
|
||||
"""Tests for graceful handling when subtask data is missing."""
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Create a dataset without subtask information."""
|
||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features)
|
||||
|
||||
# Add some frames and save
|
||||
for _ in range(5):
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Reload the dataset
|
||||
return LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
def test_no_subtask_in_features(self, dataset_without_subtasks):
|
||||
"""Test that subtask_index is not in features when not provided."""
|
||||
assert "subtask_index" not in dataset_without_subtasks.features
|
||||
|
||||
def test_getitem_without_subtask(self, dataset_without_subtasks):
|
||||
"""Test that __getitem__ works when subtask is not present."""
|
||||
item = dataset_without_subtasks[0]
|
||||
|
||||
# Item should still be retrievable
|
||||
assert item is not None
|
||||
assert "state" in item
|
||||
assert "task" in item
|
||||
|
||||
# Subtask should NOT be present
|
||||
assert "subtask" not in item
|
||||
|
||||
def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
|
||||
"""Test that subtasks metadata is None when not present."""
|
||||
assert dataset_without_subtasks.meta.subtasks is None
|
||||
|
||||
|
||||
class TestSubtaskEdgeCases:
|
||||
"""Edge case tests for subtask handling."""
|
||||
|
||||
def test_subtask_with_multiple_episodes(self):
|
||||
"""Test subtask handling with multiple episodes if available."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
# Check first and last items have valid subtasks
|
||||
first_item = dataset[0]
|
||||
last_item = dataset[len(dataset) - 1]
|
||||
|
||||
assert "subtask" in first_item
|
||||
assert "subtask" in last_item
|
||||
assert isinstance(first_item["subtask"], str)
|
||||
assert isinstance(last_item["subtask"], str)
|
||||
|
||||
def test_subtask_index_consistency(self):
|
||||
"""Test that same subtask_index returns same subtask string."""
|
||||
try:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id="lerobot/pusht-subtask",
|
||||
episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
|
||||
)
|
||||
except Exception:
|
||||
pytest.skip("Could not load test-subtask dataset")
|
||||
|
||||
if len(dataset) < 2:
|
||||
pytest.skip("Dataset too small for this test")
|
||||
|
||||
# Collect subtask_index to subtask mappings
|
||||
subtask_map = {}
|
||||
for i in range(min(len(dataset), 10)):
|
||||
item = dataset[i]
|
||||
idx = item["subtask_index"].item()
|
||||
subtask = item["subtask"]
|
||||
|
||||
if idx in subtask_map:
|
||||
# Same index should always return same subtask
|
||||
assert subtask_map[idx] == subtask, (
|
||||
f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
|
||||
)
|
||||
else:
|
||||
subtask_map[idx] = subtask
|
||||
@@ -27,7 +27,14 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_IMAGE,
|
||||
OBS_LANGUAGE,
|
||||
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_SUBTASK_TOKENS,
|
||||
OBS_STATE,
|
||||
)
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario():
|
||||
# MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length)
|
||||
assert tokens.shape == (10,) # MockTokenizer behavior for single string in list
|
||||
assert attention_mask.shape == (10,)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_subtask method
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_missing_key():
|
||||
"""Test get_subtask returns None when subtask key is missing from complementary_data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task"}, # No "subtask" key
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_none_value():
|
||||
"""Test get_subtask returns None when subtask value is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": None},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_none_complementary_data():
|
||||
"""Test get_subtask returns None when complementary_data is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data=None, # No complementary data
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_string():
|
||||
"""Test get_subtask returns list with single string when subtask is a string."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up the cube"},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == ["pick up the cube"]
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_list_of_strings():
|
||||
"""Test get_subtask returns the list when subtask is already a list of strings."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
subtask_list = ["pick up", "move to target", "place down"]
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": subtask_list},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == subtask_list
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_integer():
|
||||
"""Test get_subtask returns None when subtask is an unsupported type (integer)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": 123},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_mixed_list():
|
||||
"""Test get_subtask returns None when subtask is a list with mixed types."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_unsupported_type_dict():
|
||||
"""Test get_subtask returns None when subtask is a dictionary."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": {"key": "value"}},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result is None
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_empty_string():
|
||||
"""Test get_subtask with empty string returns list with empty string."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ""},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == [""]
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_get_subtask_empty_list():
|
||||
"""Test get_subtask with empty list returns empty list."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": []},
|
||||
)
|
||||
|
||||
result = processor.get_subtask(transition)
|
||||
assert result == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for subtask tokenization in observation method
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_when_present():
|
||||
"""Test that subtask is tokenized and added to observation when present."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up the red cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check token structure
|
||||
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert isinstance(subtask_tokens, torch.Tensor)
|
||||
assert isinstance(subtask_attention_mask, torch.Tensor)
|
||||
assert subtask_tokens.shape == (8,)
|
||||
assert subtask_attention_mask.shape == (8,)
|
||||
assert subtask_attention_mask.dtype == torch.bool
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_not_added_when_none():
|
||||
"""Test that subtask tokens are NOT added to observation when subtask is None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task"}, # No subtask
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were NOT added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
# But main task tokens should still be present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_not_added_when_subtask_value_is_none():
|
||||
"""Test that subtask tokens are NOT added when subtask value is explicitly None."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": None},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were NOT added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_list_of_strings():
|
||||
"""Test subtask tokenization with list of strings."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": ["pick up", "place down"]},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens were added to observation
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check token structure for batch
|
||||
subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8
|
||||
assert subtask_attention_mask.shape == (2, 8)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_device_cpu():
|
||||
"""Test that subtask tokens are on CPU when other tensors are on CPU."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CPU tensors
|
||||
observation = {OBS_STATE: torch.randn(10)} # CPU tensor
|
||||
action = torch.randn(5) # CPU tensor
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens are on CPU
|
||||
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
assert subtask_tokens.device.type == "cpu"
|
||||
assert subtask_attention_mask.device.type == "cpu"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_device_cuda():
|
||||
"""Test that subtask tokens are moved to CUDA when other tensors are on CUDA."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
# Create transition with CUDA tensors
|
||||
observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor
|
||||
action = torch.randn(5).cuda() # CUDA tensor
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
# Check that subtask tokens are on CUDA
|
||||
subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
assert subtask_tokens.device.type == "cuda"
|
||||
assert subtask_attention_mask.device.type == "cuda"
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_preserves_other_observation_data():
|
||||
"""Test that subtask tokenization preserves other observation data."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
original_state = torch.tensor([1.0, 2.0, 3.0])
|
||||
transition = create_transition(
|
||||
observation={"state": original_state.clone()},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that original observation data is preserved
|
||||
assert torch.equal(observation["state"], original_state)
|
||||
|
||||
# Check that both task and subtask tokens are present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_attention_mask_dtype():
|
||||
"""Test that subtask attention mask has correct dtype (bool)."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "pick up cube"},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
assert subtask_attention_mask.dtype == torch.bool
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_tokenization_deterministic():
|
||||
"""Test that subtask tokenization is deterministic for the same input."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "consistent subtask"},
|
||||
)
|
||||
|
||||
result1 = processor(transition)
|
||||
result2 = processor(transition)
|
||||
|
||||
subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS]
|
||||
subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK]
|
||||
|
||||
# Results should be identical
|
||||
assert torch.equal(subtask_tokens1, subtask_tokens2)
|
||||
assert torch.equal(subtask_mask1, subtask_mask2)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@patch("lerobot.processor.tokenizer_processor.AutoTokenizer")
|
||||
def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer):
|
||||
"""Test subtask tokenization works correctly with DataProcessorPipeline."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
|
||||
|
||||
tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6)
|
||||
robot_processor = DataProcessorPipeline(
|
||||
[tokenizer_processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": "subtask instruction"},
|
||||
)
|
||||
|
||||
result = robot_processor(transition)
|
||||
|
||||
# Check that observation exists and both tokenizations were applied
|
||||
assert TransitionKey.OBSERVATION in result
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check task tokens
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
assert f"{OBS_LANGUAGE}.attention_mask" in observation
|
||||
|
||||
# Check subtask tokens
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation
|
||||
|
||||
# Check shapes
|
||||
assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,)
|
||||
assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,)
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
def test_subtask_not_added_for_unsupported_types():
|
||||
"""Test that subtask tokens are not added when subtask has unsupported type."""
|
||||
mock_tokenizer = MockTokenizer(vocab_size=100)
|
||||
processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8)
|
||||
|
||||
# Test with integer subtask
|
||||
transition = create_transition(
|
||||
observation={"state": torch.tensor([1.0, 2.0])},
|
||||
action=torch.tensor([0.1, 0.2]),
|
||||
complementary_data={"task": "main task", "subtask": 123},
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
observation = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Subtask tokens should NOT be added for unsupported types
|
||||
assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation
|
||||
assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation
|
||||
|
||||
# But main task tokens should still be present
|
||||
assert f"{OBS_LANGUAGE}.tokens" in observation
|
||||
|
||||
Reference in New Issue
Block a user