Compare commits

...

1 Commits

Author SHA1 Message Date
Steven Palma 49cb1ee7db chore(policies): add explicit dataset dependecy to gr00t implementation 2026-06-30 16:49:08 +02:00
5 changed files with 17 additions and 3 deletions
+1
View File
@@ -218,6 +218,7 @@ groot = [
"lerobot[transformers-dep]",
"lerobot[peft-dep]",
"lerobot[diffusers-dep]",
"lerobot[dataset]", # NOTE: processor_groot builds a LeRobotDataset for relative-action training stats
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
@@ -27,7 +27,7 @@ import torchvision.transforms.v2.functional as tv_functional
from einops import rearrange
from torchvision.transforms import InterpolationMode
from lerobot.utils.import_utils import _transformers_available, require_package
from lerobot.utils.import_utils import _datasets_available, _transformers_available, require_package
if TYPE_CHECKING or _transformers_available:
from transformers import (
@@ -44,6 +44,11 @@ else:
Qwen3VLProcessor = None
Qwen3VLVideoProcessor = None
if TYPE_CHECKING or _datasets_available:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
else:
LeRobotDataset = None
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
@@ -811,7 +816,7 @@ def _make_relative_action_training_stats_from_dataset_meta(
if dataset_meta is None or repo_id is None or root is None or fps is None:
return None
from lerobot.datasets.lerobot_dataset import LeRobotDataset
require_package("datasets", extra="groot")
delta_timestamps = {ACTION: [index / fps for index in config.action_delta_indices]}
dataset = LeRobotDataset(
+1
View File
@@ -129,6 +129,7 @@ _placo_available = is_package_available("placo")
_hidapi_available = is_package_available("hidapi", import_name="hid")
# Data / serialization
_datasets_available = is_package_available("datasets")
_pandas_available = is_package_available("pandas")
_faker_available = is_package_available("faker")
+1 -1
View File
@@ -2256,7 +2256,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
assert kwargs["delta_timestamps"][ACTION] == [0.0, 1 / runtime_meta.fps]
return _RelativeStatsDataset()
monkeypatch.setattr("lerobot.datasets.lerobot_dataset.LeRobotDataset", _fake_lerobot_dataset)
monkeypatch.setattr("lerobot.policies.groot.processor_groot.LeRobotDataset", _fake_lerobot_dataset)
config._runtime_dataset_meta = runtime_meta
preprocessor, postprocessor = make_groot_pre_post_processors(config, dataset_stats=absolute_dataset_stats)
Generated
+7
View File
@@ -2957,11 +2957,17 @@ gamepad = [
{ name = "pygame" },
]
groot = [
{ name = "av" },
{ name = "datasets" },
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
{ name = "diffusers" },
{ name = "dm-tree" },
{ name = "jsonlines" },
{ name = "pandas" },
{ name = "peft" },
{ name = "pyarrow" },
{ name = "timm" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "transformers" },
]
grpcio-dep = [
@@ -3240,6 +3246,7 @@ requires-dist = [
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'annotations'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'core-scripts'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'dataset-viz'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'groot'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'hilserl'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'libero'" },
{ name = "lerobot", extras = ["dataset"], marker = "extra == 'metaworld'" },