Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-15 08:39:49 +00:00.
Add extensive language support
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
|
||||
|
||||
def test_message_recipe_validates_unknown_binding():
    """A recipe referencing a placeholder with no declared binding is rejected."""
    turns = [
        MessageTurn(role="user", content="${missing}", stream="high_level"),
        MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
    ]
    with pytest.raises(ValueError, match="unknown binding"):
        TrainingRecipe(messages=turns)
|
||||
|
||||
|
||||
def test_canonical_recipe_loads():
    """The shipped pi05_hirobot recipe parses and exposes the expected blend."""
    yaml_path = Path("src/lerobot/configs/recipes/pi05_hirobot.yaml")
    recipe = TrainingRecipe.from_yaml(yaml_path)

    expected_components = {
        "ask_vqa",
        "high_level_subtask",
        "low_level_execution",
        "memory_update",
        "user_interjection_response",
    }
    assert recipe.blend is not None
    assert set(recipe.blend) == expected_components

    total_weight = sum(component.weight for component in recipe.blend.values())
    assert total_weight == pytest.approx(0.96)
|
||||
@@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.io_utils import write_info
|
||||
from lerobot.datasets.language import (
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
STYLE_REGISTRY,
|
||||
column_for_style,
|
||||
language_events_arrow_type,
|
||||
language_feature_info,
|
||||
language_persistent_arrow_type,
|
||||
)
|
||||
from lerobot.datasets.utils import DEFAULT_DATA_PATH
|
||||
|
||||
|
||||
def test_language_arrow_schema_has_expected_fields():
    """Both language columns share one struct row type with the canonical fields."""
    persistent_rows = language_persistent_arrow_type().value_type
    event_rows = language_events_arrow_type().value_type

    assert isinstance(persistent_rows, pa.StructType)
    expected_fields = ["role", "content", "style", "timestamp", "tool_calls"]
    assert persistent_rows.names == expected_fields
    assert event_rows == persistent_rows
|
||||
|
||||
|
||||
def test_style_registry_routes_columns():
    """Every known style (and the None default) routes to the correct column."""
    assert PERSISTENT_STYLES == {"subtask", "plan", "memory"}
    assert EVENT_ONLY_STYLES == {"interjection", "vqa"}
    assert (PERSISTENT_STYLES | EVENT_ONLY_STYLES) <= STYLE_REGISTRY

    # Table-driven check of the style -> storage-column routing.
    routing = {
        "subtask": LANGUAGE_PERSISTENT,
        "plan": LANGUAGE_PERSISTENT,
        "memory": LANGUAGE_PERSISTENT,
        "interjection": LANGUAGE_EVENTS,
        "vqa": LANGUAGE_EVENTS,
        None: LANGUAGE_EVENTS,
    }
    for style, expected_column in routing.items():
        assert column_for_style(style) == expected_column
|
||||
|
||||
|
||||
def test_unknown_style_rejected():
    """A style outside the registry raises a descriptive ValueError."""
    bogus_style = "surprise"
    with pytest.raises(ValueError, match="Unknown language style"):
        column_for_style(bogus_style)
|
||||
|
||||
|
||||
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
    """Language columns injected into the stored parquet survive a dataset reload."""
    root = tmp_path / "language_dataset"
    dataset = empty_lerobot_dataset_factory(
        root=root,
        features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
        use_videos=False,
    )
    for values in ([0.0, 1.0], [1.0, 2.0]):
        dataset.add_frame({"state": np.array(values, dtype=np.float32), "task": "tidy"})
    dataset.save_episode()
    dataset.finalize()

    persistent_rows = [
        {
            "role": "assistant",
            "content": "reach for the cup",
            "style": "subtask",
            "timestamp": 0.0,
            "tool_calls": None,
        }
    ]
    vqa_event = {
        "role": "user",
        "content": "what is visible?",
        "style": "vqa",
        "timestamp": 0.0,
        "tool_calls": None,
    }

    # Inject the language columns directly into the stored parquet file.
    data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
    frame = pd.read_parquet(data_path)
    frame[LANGUAGE_PERSISTENT] = [persistent_rows, persistent_rows]
    frame[LANGUAGE_EVENTS] = [[vqa_event], []]
    frame.to_parquet(data_path)

    # Advertise the new columns in the dataset info so the loader picks them up.
    info = dataset.meta.info
    info["features"].update(language_feature_info())
    write_info(info, root)

    reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)

    first, second = reloaded[0], reloaded[1]
    assert first[LANGUAGE_PERSISTENT] == persistent_rows
    assert first[LANGUAGE_EVENTS] == [vqa_event]
    assert second[LANGUAGE_PERSISTENT] == persistent_rows
    assert second[LANGUAGE_EVENTS] == []
|
||||
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
|
||||
|
||||
|
||||
def row(role, content, style, timestamp, tool_calls=None):
    """Build one language row dict in the canonical column layout."""
    return dict(
        role=role,
        content=content,
        style=style,
        timestamp=timestamp,
        tool_calls=tool_calls,
    )


# Persistent rows: plan/memory/subtask streams that stay active from their timestamp on.
PERSISTENT = [
    row("assistant", "plan 0", "plan", 0.0),
    row("assistant", "memory 0", "memory", 0.0),
    row("assistant", "subtask 0", "subtask", 0.0),
    row("assistant", "memory 1", "memory", 1.0),
    row("assistant", "subtask 1", "subtask", 1.0),
]

# Point-in-time events: a VQA exchange at t=1.0 and an interjection + spoken reply at t=2.0.
_SAY_TOOL_CALL = [
    {"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}
]
EVENTS = [
    row("user", "what is visible?", "vqa", 1.0),
    row("assistant", '{"count": 2}', "vqa", 1.0),
    row("user", "skip wiping", "interjection", 2.0),
    row("assistant", None, None, 2.0, _SAY_TOOL_CALL),
]
|
||||
|
||||
|
||||
def test_resolver_temporal_semantics():
    """active_at picks the latest row at or before t; emitted_at only matches exactly t."""
    for t, expected in ((0.5, "subtask 0"), (1.0, "subtask 1")):
        assert active_at(t, persistent=PERSISTENT, style="subtask")["content"] == expected

    miss = emitted_at(0.5, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")
    assert miss is None

    hit = emitted_at(1.0, persistent=PERSISTENT, events=EVENTS, style="vqa", role="assistant")
    assert hit["content"] == '{"count": 2}'
|
||||
|
||||
|
||||
def test_persistent_relative_resolvers_reject_event_styles():
    """Persistent-only resolvers must refuse event-only styles."""
    for resolver, style in ((active_at, "vqa"), (nth_prev, "interjection")):
        with pytest.raises(ValueError, match="event-only"):
            resolver(1.0, persistent=PERSISTENT, style=style)
|
||||
|
||||
|
||||
def test_nth_prev_and_next():
    """offset=1 steps one subtask backward or forward relative to t."""
    previous = nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)
    upcoming = nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)
    assert previous["content"] == "subtask 0"
    assert upcoming["content"] == "subtask 1"
|
||||
|
||||
|
||||
def test_substitution_if_present_multimodal_and_tool_calls():
    """Placeholders, if_present gating, multimodal content, and tool calls all render."""
    user_turn = MessageTurn(
        role="user",
        content=[
            {"type": "image", "feature": "observation.images.top"},
            {"type": "text", "text": "${task}: ${interjection}"},
        ],
        stream="high_level",
        if_present="interjection",
    )
    assistant_turn = MessageTurn(
        role="assistant",
        content="${plan}",
        stream="high_level",
        target=True,
        tool_calls_from="speech",
    )
    recipe = TrainingRecipe(
        messages=[user_turn, assistant_turn],
        bindings={"plan": "active_at(t, style=plan)"},
    )

    rendered = render_sample(
        recipe=recipe,
        persistent=PERSISTENT,
        events=EVENTS,
        t=2.0,
        sample_idx=0,
        task="clean kitchen",
    )

    user_msg, assistant_msg = rendered["messages"]
    assert user_msg["content"][1]["text"] == "clean kitchen: skip wiping"
    assert assistant_msg["content"] == "plan 0"
    assert assistant_msg["tool_calls"][0]["function"]["name"] == "say"
    assert rendered["message_streams"] == ["high_level", "high_level"]
    assert rendered["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_exact_event_miss_returns_none_when_target_skips():
    """When the target turn's if_present binding has no event at t, the sample is dropped."""
    query_turn = MessageTurn(
        role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"
    )
    answer_turn = MessageTurn(
        role="assistant",
        content="${vqa}",
        stream="high_level",
        target=True,
        if_present="vqa",
    )
    recipe = TrainingRecipe(messages=[query_turn, answer_turn])

    # No VQA event exists at t=0.0, so rendering must yield no sample at all.
    rendered = render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=0)
    assert rendered is None
|
||||
|
||||
|
||||
def test_deterministic_blend_sampling():
    """Blend selection is a pure function of sample_idx: repeated renders agree."""

    def branch(reply):
        # Two equally weighted branches that differ only in the assistant reply.
        return TrainingRecipe(
            weight=1.0,
            messages=[
                MessageTurn(role="user", content="${task}", stream="high_level"),
                MessageTurn(role="assistant", content=reply, stream="high_level", target=True),
            ],
        )

    recipe = TrainingRecipe(blend={"a": branch("a"), "b": branch("b")})

    render_kwargs = dict(
        recipe=recipe, persistent=PERSISTENT, events=EVENTS, t=0.0, sample_idx=123, task="x"
    )
    assert render_sample(**render_kwargs) == render_sample(**render_kwargs)
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
    """The low_level_execution branch of the shipped recipe renders a subtask target."""
    canonical = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
    low_level_only = TrainingRecipe(blend={"low": canonical.blend["low_level_execution"]})

    rendered = render_sample(
        recipe=low_level_only,
        persistent=PERSISTENT,
        events=[],
        t=0.5,
        sample_idx=0,
        task="clean kitchen",
    )

    assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
    assert rendered["message_streams"][-1] == "low_level"
    assert rendered["target_message_indices"] == [1]
|
||||
@@ -1,193 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for subtask functionality in LeRobotDataset.
|
||||
|
||||
These tests verify that:
|
||||
- Subtask information is correctly loaded from datasets that have subtask data
|
||||
- The __getitem__ method correctly adds subtask strings to returned items
|
||||
- Subtask handling gracefully handles missing data
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import pandas as pd # noqa: E402
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
class TestSubtaskDataset:
    """Tests for subtask handling in LeRobotDataset."""

    @pytest.fixture
    def subtask_dataset(self):
        """Load the lerobot/pusht-subtask hub dataset restricted to episodes 1-11."""
        return LeRobotDataset(
            repo_id="lerobot/pusht-subtask",
            episodes=list(range(1, 12)),
        )

    def test_subtask_dataset_loads(self, subtask_dataset):
        """The dataset loads and is non-empty."""
        assert subtask_dataset is not None
        assert len(subtask_dataset) > 0

    def test_subtask_metadata_loaded(self, subtask_dataset):
        """Subtask metadata is exposed as a pandas DataFrame."""
        subtasks = subtask_dataset.meta.subtasks
        assert subtasks is not None
        assert isinstance(subtasks, pd.DataFrame)

    def test_subtask_index_in_features(self, subtask_dataset):
        """Datasets carrying subtasks expose a subtask_index feature."""
        assert "subtask_index" in subtask_dataset.features

    def test_getitem_returns_subtask_string(self, subtask_dataset):
        """__getitem__ attaches a non-empty subtask string to each item."""
        sample = subtask_dataset[0]
        assert "subtask" in sample
        assert isinstance(sample["subtask"], str)
        assert len(sample["subtask"]) > 0  # must not be empty

    def test_getitem_has_subtask_index(self, subtask_dataset):
        """__getitem__ includes subtask_index as a tensor."""
        sample = subtask_dataset[0]
        assert "subtask_index" in sample
        assert isinstance(sample["subtask_index"], torch.Tensor)

    def test_subtask_index_maps_to_valid_subtask(self, subtask_dataset):
        """subtask_index resolves to the matching row of the subtask metadata."""
        sample = subtask_dataset[0]
        index = sample["subtask_index"].item()
        expected_subtask = subtask_dataset.meta.subtasks.iloc[index].name
        assert sample["subtask"] == expected_subtask

    def test_all_items_have_subtask(self, subtask_dataset):
        """The first few items all carry a subtask string."""
        for i in range(min(len(subtask_dataset), 5)):
            sample = subtask_dataset[i]
            assert "subtask" in sample
            assert isinstance(sample["subtask"], str)

    def test_task_and_subtask_coexist(self, subtask_dataset):
        """Both the task and subtask strings are present on a returned item."""
        sample = subtask_dataset[0]
        for key in ("task", "subtask"):
            assert key in sample
            assert isinstance(sample[key], str)
|
||||
|
||||
|
||||
class TestSubtaskDatasetMissing:
    """Tests for graceful handling when subtask data is missing."""

    @pytest.fixture
    def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory):
        """Build, finalize, and reload a small dataset that has no subtask data."""
        dataset = empty_lerobot_dataset_factory(
            root=tmp_path / "no_subtask",
            features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
        )
        for _ in range(5):
            dataset.add_frame({"state": torch.randn(2), "task": "Test task"})
        dataset.save_episode()
        dataset.finalize()
        # Reload from disk so the loading path is exercised, not the writer state.
        return LeRobotDataset(dataset.repo_id, root=dataset.root)

    def test_no_subtask_in_features(self, dataset_without_subtasks):
        """subtask_index is absent from features when no subtasks were recorded."""
        assert "subtask_index" not in dataset_without_subtasks.features

    def test_getitem_without_subtask(self, dataset_without_subtasks):
        """Items load fine and simply omit the subtask key."""
        sample = dataset_without_subtasks[0]
        assert sample is not None
        assert "state" in sample
        assert "task" in sample
        # No subtask data was written, so the key must not appear.
        assert "subtask" not in sample

    def test_subtasks_metadata_is_none(self, dataset_without_subtasks):
        """Subtask metadata defaults to None when absent."""
        assert dataset_without_subtasks.meta.subtasks is None
|
||||
|
||||
|
||||
class TestSubtaskEdgeCases:
    """Edge case tests for subtask handling."""

    @staticmethod
    def _load_subtask_dataset():
        """Load the hub subtask dataset (episodes 1-11), skipping the test if unavailable.

        Both edge-case tests previously duplicated this try/except loading logic;
        it is centralized here so the repo_id and episode list stay in one place.
        """
        try:
            return LeRobotDataset(
                repo_id="lerobot/pusht-subtask",
                episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
            )
        except Exception:
            pytest.skip("Could not load test-subtask dataset")

    def test_subtask_with_multiple_episodes(self):
        """First and last items across multiple episodes carry valid subtasks."""
        dataset = self._load_subtask_dataset()

        first_item = dataset[0]
        last_item = dataset[len(dataset) - 1]

        for item in (first_item, last_item):
            assert "subtask" in item
            assert isinstance(item["subtask"], str)

    def test_subtask_index_consistency(self):
        """The same subtask_index must always map to the same subtask string."""
        dataset = self._load_subtask_dataset()

        if len(dataset) < 2:
            pytest.skip("Dataset too small for this test")

        # Collect subtask_index -> subtask mappings over the first few items.
        subtask_map = {}
        for i in range(min(len(dataset), 10)):
            item = dataset[i]
            idx = item["subtask_index"].item()
            subtask = item["subtask"]

            if idx in subtask_map:
                assert subtask_map[idx] == subtask, (
                    f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'"
                )
            else:
                subtask_map[idx] = subtask
|
||||
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.processor.converters import create_transition
|
||||
from lerobot.processor.render_messages_processor import RenderMessagesStep
|
||||
from lerobot.types import TransitionKey
|
||||
|
||||
|
||||
def test_render_messages_step_noops_without_language_columns():
    """Without language columns in the transition, the step passes it through untouched."""
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    step = RenderMessagesStep(recipe)
    transition = create_transition(complementary_data={"task": "do it"})

    assert step(transition) == transition
|
||||
|
||||
|
||||
def test_render_messages_step_renders_and_drops_raw_language():
    """The step renders messages from raw language columns and strips the raw data."""
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${task}", stream="high_level"),
            MessageTurn(role="assistant", content="${subtask}", stream="low_level", target=True),
        ]
    )
    subtask_row = {
        "role": "assistant",
        "content": "reach carefully",
        "style": "subtask",
        "timestamp": 0.0,
        "tool_calls": None,
    }
    transition = create_transition(
        complementary_data={
            "task": "do it",
            "timestamp": torch.tensor(0.0),
            "index": torch.tensor(7),
            "language_persistent": [subtask_row],
            "language_events": [],
        }
    )

    result = RenderMessagesStep(recipe)(transition)
    data = result[TransitionKey.COMPLEMENTARY_DATA]

    # Raw language columns are consumed by rendering, not forwarded downstream.
    assert "language_persistent" not in data
    assert "language_events" not in data
    assert data["messages"][-1]["content"] == "reach carefully"
    assert data["message_streams"] == ["high_level", "low_level"]
    assert data["target_message_indices"] == [1]
|
||||
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
|
||||
|
||||
def test_lerobot_collate_preserves_messages_and_drops_raw_language():
    """Collation stacks tensors, keeps message fields as lists, and drops raw language."""

    def make_sample(i, reply):
        # Minimal per-sample dict matching what the dataloader would produce.
        return {
            "index": torch.tensor(i),
            "messages": [{"role": "assistant", "content": reply}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        }

    collated = lerobot_collate_fn([make_sample(0, "a"), make_sample(1, "b")])

    assert collated["index"].tolist() == [0, 1]
    assert collated["messages"][0][0]["content"] == "a"
    assert collated["messages"][1][0]["content"] == "b"
    assert collated["message_streams"] == [["low_level"], ["low_level"]]
    assert collated["target_message_indices"] == [[0], [0]]
    assert "language_persistent" not in collated
    assert "language_events" not in collated
|
||||
Reference in New Issue
Block a user