mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-10 13:19:53 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 | |||
| 49755a3d9e |
+3
-3
@@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -231,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
|
||||
@@ -32,7 +32,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
_serialized_state_filenames: tuple[str | None, ...] | None = field(
|
||||
default=None,
|
||||
init=False,
|
||||
repr=False,
|
||||
)
|
||||
|
||||
def __call__(self, data: TInput) -> TOutput:
|
||||
"""Processes input data through the full pipeline.
|
||||
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs):
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
def _get_sanitized_name(self) -> str:
|
||||
"""Return a filename-safe version of the pipeline name.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
Returns:
|
||||
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
|
||||
# Sanitize the pipeline name to create a valid filename prefix.
|
||||
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
|
||||
@staticmethod
|
||||
def _get_state_filename(
|
||||
*,
|
||||
step_index: int,
|
||||
registry_name: str | None,
|
||||
sanitized_name: str,
|
||||
) -> str:
|
||||
"""Return the safetensors filename for one stateful processor step.
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
Args:
|
||||
step_index: The index of the processor step in this pipeline.
|
||||
registry_name: The registered processor step name, if available.
|
||||
sanitized_name: The filename-safe pipeline name.
|
||||
|
||||
config: dict[str, Any] = {
|
||||
Returns:
|
||||
The state filename used by the existing disk serialization format.
|
||||
"""
|
||||
if registry_name:
|
||||
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
|
||||
return f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
|
||||
@staticmethod
|
||||
def _get_state_key(state_filename: str) -> str:
|
||||
"""Return the in-memory state key for a serialized state filename.
|
||||
|
||||
Args:
|
||||
state_filename: The `.safetensors` filename from the serialized config.
|
||||
|
||||
Returns:
|
||||
The state key used by the in-memory pipeline state dictionary.
|
||||
"""
|
||||
return state_filename.removesuffix(".safetensors")
|
||||
|
||||
@staticmethod
|
||||
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
|
||||
"""Return serialized state filenames in step order.
|
||||
|
||||
Args:
|
||||
loaded_config: A validated processor pipeline config.
|
||||
|
||||
Returns:
|
||||
A tuple containing each step's serialized state filename, or None for stateless steps.
|
||||
"""
|
||||
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
|
||||
|
||||
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
|
||||
"""Return expected state filenames in step order for `load_state_dict()`.
|
||||
|
||||
Returns:
|
||||
The preserved serialized state filenames when available, otherwise filenames derived from
|
||||
current non-empty step state.
|
||||
"""
|
||||
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
|
||||
self.steps
|
||||
):
|
||||
return self._serialized_state_filenames
|
||||
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
state_filenames: list[str | None] = []
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
state_filenames.append(None)
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filenames.append(
|
||||
self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
)
|
||||
|
||||
return tuple(state_filenames)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the JSON-serializable pipeline configuration.
|
||||
|
||||
Returns:
|
||||
A dictionary with the same content that `save_pretrained()` writes as JSON.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_config: dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
# Iterate through each step to build its configuration entry.
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
|
||||
step_entry: dict[str, Any] = {}
|
||||
# Prefer registry name for portability, otherwise fall back to full class path.
|
||||
|
||||
if registry_name:
|
||||
step_entry["registry_name"] = registry_name
|
||||
else:
|
||||
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
|
||||
)
|
||||
|
||||
# Save step configuration if `get_config` is implemented.
|
||||
if hasattr(processor_step, "get_config"):
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
step_entry["config"] = processor_step.get_config()
|
||||
|
||||
# Save step state if `state_dict` is implemented and returns a non-empty dict.
|
||||
if hasattr(processor_step, "state_dict"):
|
||||
state = processor_step.state_dict()
|
||||
if state:
|
||||
# Clone tensors to avoid modifying the original state.
|
||||
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if step_state_dict:
|
||||
step_entry["state_file"] = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
|
||||
# Create a unique filename for the state file.
|
||||
if registry_name:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
|
||||
else:
|
||||
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
|
||||
pipeline_config["steps"].append(step_entry)
|
||||
|
||||
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
|
||||
step_entry["state_file"] = state_filename
|
||||
return pipeline_config
|
||||
|
||||
config["steps"].append(step_entry)
|
||||
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""Return pipeline state tensors grouped by state key.
|
||||
|
||||
# Write the main configuration JSON file.
|
||||
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
|
||||
json.dump(config, file_pointer, indent=2)
|
||||
Returns:
|
||||
A dictionary mapping suffixless state keys to cloned step state dictionaries.
|
||||
"""
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for step_index, processor_step in enumerate(self.steps):
|
||||
step_state_dict = processor_step.state_dict()
|
||||
if not step_state_dict:
|
||||
continue
|
||||
|
||||
registry_name = getattr(processor_step.__class__, "_registry_name", None)
|
||||
state_filename = self._get_state_filename(
|
||||
step_index=step_index,
|
||||
registry_name=registry_name,
|
||||
sanitized_name=sanitized_name,
|
||||
)
|
||||
state_key = self._get_state_key(state_filename)
|
||||
pipeline_state_dict[state_key] = {
|
||||
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
|
||||
}
|
||||
|
||||
return pipeline_state_dict
|
||||
|
||||
def load_state_dict(
|
||||
self,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]],
|
||||
) -> None:
|
||||
"""Load pipeline state tensors into the existing steps.
|
||||
|
||||
Args:
|
||||
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
|
||||
|
||||
Raises:
|
||||
KeyError: If loading finds missing expected state or unexpected extra state.
|
||||
"""
|
||||
expected_state_filenames = self._get_state_filenames_for_loading()
|
||||
used_state_keys: set[str] = set()
|
||||
|
||||
for step_index, (processor_step, state_filename) in enumerate(
|
||||
zip(self.steps, expected_state_filenames, strict=True)
|
||||
):
|
||||
if state_filename is None:
|
||||
continue
|
||||
|
||||
state_key = self._get_state_key(state_filename)
|
||||
if state_key not in state_dict:
|
||||
raise KeyError(
|
||||
f"Missing state key '{state_key}' for processor step {step_index}. "
|
||||
f"Available state keys: {sorted(state_dict.keys())}"
|
||||
)
|
||||
|
||||
processor_step.load_state_dict(state_dict[state_key])
|
||||
used_state_keys.add(state_key)
|
||||
|
||||
unexpected_state_keys = set(state_dict) - used_state_keys
|
||||
if unexpected_state_keys:
|
||||
expected_state_key_set = {
|
||||
self._get_state_key(state_filename)
|
||||
for state_filename in expected_state_filenames
|
||||
if state_filename is not None
|
||||
}
|
||||
raise KeyError(
|
||||
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
|
||||
f"Expected state keys: {sorted(expected_state_key_set)}"
|
||||
)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
|
||||
"""Internal method to comply with `HubMixin`'s saving mechanism.
|
||||
|
||||
This method does the actual saving work and is called by HubMixin.save_pretrained.
|
||||
"""
|
||||
config_filename = kwargs.pop("config_filename", None)
|
||||
sanitized_name = self._get_sanitized_name()
|
||||
|
||||
if config_filename is None:
|
||||
config_filename = f"{sanitized_name}.json"
|
||||
|
||||
pipeline_config = self.get_config()
|
||||
pipeline_state_dict = self.state_dict()
|
||||
|
||||
for state_key, step_state_dict in pipeline_state_dict.items():
|
||||
state_filename = f"{state_key}.safetensors"
|
||||
save_file(step_state_dict, save_directory / state_filename)
|
||||
|
||||
with open(save_directory / config_filename, "w") as file_pointer:
|
||||
json.dump(pipeline_config, file_pointer, indent=2)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
cls._validate_overrides_used(validated_overrides, loaded_config)
|
||||
|
||||
# 5. Construct and return the final pipeline instance
|
||||
return cls(
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=loaded_config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: dict[str, Any],
|
||||
*,
|
||||
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
to_transition: Callable[[TInput], EnvTransition] | None = None,
|
||||
to_output: Callable[[EnvTransition], TOutput] | None = None,
|
||||
) -> DataProcessorPipeline[TInput, TOutput]:
|
||||
"""Build a pipeline from an in-memory config and optional state tensors.
|
||||
|
||||
Args:
|
||||
config: A config dictionary with the same structure as the saved processor JSON.
|
||||
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
|
||||
overrides: Optional constructor overrides keyed by registry name or class name.
|
||||
to_transition: Optional converter from input data to `EnvTransition`.
|
||||
to_output: Optional converter from `EnvTransition` to output data.
|
||||
|
||||
Returns:
|
||||
A processor pipeline built from the config and optional state.
|
||||
"""
|
||||
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
|
||||
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
|
||||
cls._validate_overrides_used(remaining_override_keys, config)
|
||||
|
||||
pipeline = cls(
|
||||
steps=steps,
|
||||
name=config.get("name", "DataProcessorPipeline"),
|
||||
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
|
||||
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
|
||||
)
|
||||
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
pipeline.load_state_dict(state_dict)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def _load_config(
|
||||
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def _validate_loaded_config(
|
||||
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
|
||||
) -> None:
|
||||
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
|
||||
"""Validate that a config was loaded and is a valid processor config.
|
||||
|
||||
This method validates processor config format with intelligent migration detection:
|
||||
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
|
||||
Args:
|
||||
model_id: The model identifier (used for migration detection)
|
||||
loaded_config: The loaded config dictionary (guaranteed non-None)
|
||||
loaded_config: The loaded config value to validate (may be non-dict)
|
||||
config_filename: The config filename that was loaded (for error messages)
|
||||
|
||||
Raises:
|
||||
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
model_id,
|
||||
f"Config file '{config_filename}' is not a valid processor configuration",
|
||||
)
|
||||
loaded_config_description = (
|
||||
list(loaded_config.keys())
|
||||
if isinstance(loaded_config, dict)
|
||||
else type(loaded_config).__name__
|
||||
)
|
||||
raise ValueError(
|
||||
f"Config file '{config_filename}' is not a valid processor configuration. "
|
||||
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
|
||||
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
ImportError: If a step class cannot be imported or found in registry
|
||||
ValueError: If a step cannot be instantiated with its configuration
|
||||
"""
|
||||
steps: list[ProcessorStep] = []
|
||||
override_keys = set(overrides.keys())
|
||||
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
# 1. Get step class and key
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
|
||||
# 2. Instantiate step with overrides
|
||||
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
# 3. Load step state if available
|
||||
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
|
||||
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
|
||||
|
||||
# 4. Track used overrides
|
||||
if step_key in override_keys:
|
||||
override_keys.discard(step_key)
|
||||
return steps, remaining_override_keys
|
||||
|
||||
steps.append(step_instance)
|
||||
@classmethod
|
||||
def _build_steps_from_config(
|
||||
cls,
|
||||
loaded_config: dict[str, Any],
|
||||
overrides: dict[str, Any],
|
||||
) -> tuple[list[ProcessorStep], set[str]]:
|
||||
"""Build processor steps from config without loading tensor state.
|
||||
|
||||
return steps, override_keys
|
||||
Args:
|
||||
loaded_config: The loaded processor configuration.
|
||||
overrides: User-provided constructor overrides keyed by step key.
|
||||
|
||||
Returns:
|
||||
A tuple containing instantiated steps and override keys that did not match a step.
|
||||
"""
|
||||
processor_steps: list[ProcessorStep] = []
|
||||
remaining_override_keys = set(overrides.keys())
|
||||
|
||||
for step_entry in loaded_config["steps"]:
|
||||
step_class, step_key = cls._resolve_step_class(step_entry)
|
||||
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
|
||||
|
||||
if step_key in remaining_override_keys:
|
||||
remaining_override_keys.discard(step_key)
|
||||
|
||||
processor_steps.append(processor_step)
|
||||
|
||||
return processor_steps, remaining_override_keys
|
||||
|
||||
@classmethod
|
||||
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
|
||||
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_processor_config(cls, config: dict) -> bool:
|
||||
def _is_processor_config(cls, config: Any) -> bool:
|
||||
"""Check if config follows DataProcessorPipeline format.
|
||||
|
||||
This method validates the processor configuration structure:
|
||||
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
|
||||
Returns:
|
||||
True if config follows valid DataProcessorPipeline format, False otherwise
|
||||
"""
|
||||
if not isinstance(config, dict):
|
||||
return False
|
||||
|
||||
# Must have a "steps" field with a list of step configurations
|
||||
if not isinstance(config.get("steps"), list):
|
||||
return False
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import Any
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
class MockLazyTensorStateStep(ProcessorStep):
|
||||
"""Mock step whose tensor state is not present in constructor config."""
|
||||
|
||||
def __init__(
|
||||
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
|
||||
):
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
self.tensor_state: torch.Tensor | None = None
|
||||
|
||||
if initial_value is not None:
|
||||
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Return the transition unchanged."""
|
||||
return transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return constructor config while intentionally omitting tensor state."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"scale": self.scale,
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, torch.Tensor]:
|
||||
"""Return tensor state only after it has been initialized or loaded."""
|
||||
if self.tensor_state is None:
|
||||
return {}
|
||||
|
||||
return {"tensor_state": self.tensor_state}
|
||||
|
||||
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
|
||||
"""Load tensor state."""
|
||||
self.tensor_state = state["tensor_state"].clone()
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""Return features unchanged."""
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
|
||||
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
|
||||
"""Registered lazy tensor state step for registry-based serialization tests."""
|
||||
|
||||
|
||||
def test_empty_pipeline():
|
||||
"""Test pipeline with no steps."""
|
||||
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
|
||||
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
|
||||
assert torch.allclose(loaded_step.running_mean, step.running_mean)
|
||||
|
||||
|
||||
def test_get_config_matches_saved_json():
|
||||
"""Test that in-memory config matches the config written by save_pretrained."""
|
||||
stateless_step = MockStep(name="stateless")
|
||||
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
|
||||
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
|
||||
|
||||
in_memory_config = pipeline.get_config()
|
||||
|
||||
assert pipeline.get_config() == in_memory_config
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
config_path = Path(tmp_dir) / "memory_pipeline.json"
|
||||
with open(config_path) as file_pointer:
|
||||
saved_config = json.load(file_pointer)
|
||||
|
||||
assert in_memory_config == saved_config
|
||||
assert "state_file" not in in_memory_config["steps"][0]
|
||||
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
|
||||
|
||||
|
||||
def test_state_dict_matches_saved_safetensors():
|
||||
"""Test that in-memory state matches the safetensors written by save_pretrained."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
|
||||
|
||||
in_memory_state_dict = pipeline.state_dict()
|
||||
state_filename = "stateful_pipeline_step_0.safetensors"
|
||||
state_key = "stateful_pipeline_step_0"
|
||||
|
||||
assert set(in_memory_state_dict) == {state_key}
|
||||
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
|
||||
|
||||
in_memory_state_dict[state_key]["tensor_state"].add_(1)
|
||||
assert stateful_step.tensor_state is not None
|
||||
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
|
||||
|
||||
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
|
||||
|
||||
|
||||
def test_save_pretrained_still_writes_expected_serialization_files():
|
||||
"""Test that save_pretrained keeps the existing config and state filenames."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
save_path = Path(tmp_dir)
|
||||
assert (save_path / "policy_preprocessor.json").exists()
|
||||
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
|
||||
|
||||
|
||||
def test_from_config_round_trips_stateful_pipeline():
|
||||
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
|
||||
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
|
||||
|
||||
|
||||
def test_from_config_round_trips_registered_stateful_pipeline():
|
||||
"""Test that from_config resolves registry steps and loads their named tensor state."""
|
||||
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
|
||||
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
|
||||
|
||||
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
|
||||
assert config["steps"][0]["state_file"] == state_filename
|
||||
assert set(pipeline_state_dict) == {state_key}
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
|
||||
assert loaded_step.tensor_state is not None
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
|
||||
|
||||
|
||||
def test_from_config_preserves_state_metadata_for_empty_initial_state():
|
||||
"""Test in-memory loading when rebuilt steps start without tensor state."""
|
||||
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.state_dict() == {}
|
||||
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
|
||||
|
||||
loaded_pipeline.load_state_dict(pipeline_state_dict)
|
||||
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
|
||||
|
||||
|
||||
def test_from_config_applies_overrides_before_state_loading():
|
||||
"""Test that constructor overrides and tensor state loading are separate operations."""
|
||||
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
|
||||
config = pipeline.get_config()
|
||||
pipeline_state_dict = pipeline.state_dict()
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(
|
||||
config,
|
||||
state_dict=pipeline_state_dict,
|
||||
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
|
||||
)
|
||||
loaded_step = loaded_pipeline.steps[0]
|
||||
|
||||
assert isinstance(loaded_step, MockLazyTensorStateStep)
|
||||
assert loaded_step.scale == 5.0
|
||||
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_missing_expected_state():
|
||||
"""Test loading raises when serialized config expects missing state."""
|
||||
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
|
||||
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
|
||||
|
||||
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
|
||||
loaded_pipeline.load_state_dict({})
|
||||
|
||||
|
||||
def test_load_state_dict_raises_on_unexpected_extra_state():
|
||||
"""Test loading raises on unexpected top-level state keys."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
|
||||
|
||||
with pytest.raises(KeyError, match="extra"):
|
||||
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
|
||||
|
||||
|
||||
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
|
||||
"""Test stateless in-memory serialization and loading."""
|
||||
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
|
||||
config = pipeline.get_config()
|
||||
config_without_name = {"steps": config["steps"]}
|
||||
|
||||
assert pipeline.state_dict() == {}
|
||||
assert all("state_file" not in step_entry for step_entry in config["steps"])
|
||||
|
||||
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
|
||||
|
||||
assert loaded_pipeline.name == "DataProcessorPipeline"
|
||||
assert loaded_pipeline.state_dict() == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
|
||||
def test_from_config_rejects_non_dict_config(invalid_config):
|
||||
"""Test from_config reports invalid top-level config values cleanly."""
|
||||
with pytest.raises(ValueError, match="not a valid processor configuration"):
|
||||
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class MockModuleStep(ProcessorStep, nn.Module):
|
||||
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
|
||||
|
||||
|
||||
@@ -1764,7 +1764,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "gym-aloha"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dm-control" },
|
||||
@@ -1772,14 +1772,14 @@ dependencies = [
|
||||
{ name = "imageio", extra = ["ffmpeg"] },
|
||||
{ name = "mujoco" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-hil"
|
||||
version = "0.1.13"
|
||||
version = "0.1.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "gymnasium" },
|
||||
@@ -1789,9 +1789,9 @@ dependencies = [
|
||||
{ name = "pygame" },
|
||||
{ name = "pynput" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1881,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
|
||||
|
||||
[[package]]
|
||||
name = "hf-libero"
|
||||
version = "0.1.3"
|
||||
version = "0.1.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bddl", marker = "sys_platform == 'linux'" },
|
||||
@@ -1902,7 +1902,10 @@ dependencies = [
|
||||
{ name = "transformers", marker = "sys_platform == 'linux'" },
|
||||
{ name = "wandb", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
@@ -3090,12 +3093,12 @@ requires-dist = [
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
|
||||
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
|
||||
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
|
||||
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
|
||||
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
|
||||
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },
|
||||
|
||||
Reference in New Issue
Block a user