Compare commits

...

6 Commits

Author SHA1 Message Date
CarolinePascal 0172b5bdaf feat(trim): adding optional trimming option in reencode_video 2026-06-11 21:40:50 +02:00
Pepijn 41166b39fb fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768)
* fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync

In distributed training, accelerate can only synchronize the shuffle
permutation across ranks when the sampler exposes a generator attribute.
EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch
shards relied on every rank's global CPU RNG staying in lockstep forever;
any rank-asymmetric RNG consumption (e.g. eval rollouts on the main
process only) silently desynced the permutations and ranks trained on
overlapping/missing samples.

* fix(train): seed sampler generator and gate dataset download per node

- Pass a generator seeded with cfg.seed to EpisodeAwareSampler so
  accelerator.prepare registers it as the synchronized RNG and the
  shuffle order is reproducible.
- Gate the initial make_dataset call on is_local_main_process instead of
  is_main_process: the global main process only exists on node 0, so on
  every other node all local ranks were downloading the dataset and
  building the Arrow cache concurrently.
2026-06-11 11:07:42 +02:00
Steven Palma 79c6821407 chore(dependecies): update mujoco transitives (#3756) 2026-06-10 12:58:55 +02:00
Steven Palma 507083249f Revert "fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751)" (#3754)
This reverts commit bd22407d93.
2026-06-10 10:38:42 +02:00
Caroline Pascal bd22407d93 fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751)
* fix(pyproject): adding ceiling bound on mujoco (<3.9.0)

* chore(uv.lock): updating uv.lock

* fix(linux): adding missing linux dependencies

* chore(uv.lock): updating uv.lock
2026-06-09 23:31:43 +02:00
Adil Zouitine 49755a3d9e feat(processor): Add in-memory processor pipeline serialization (#3732)
* feat(processor): add in-memory pipeline serialization

Expose processor pipeline config and tensor state without requiring temporary files, so processors can be transported, compared, or hashed directly in memory.

* feat(processor): enhance DataProcessorPipeline with registry support

- Added a new RegisteredLazyTensorStateStep for registry-based serialization tests.
- Improved state filename handling in _get_state_filename method.
- Refactored validation logic in _validate_loaded_config to simplify parameter types.
- Updated tests to verify registry step functionality and ensure correct state loading.

* refactor(processor): update state handling in DataProcessorPipeline

- Introduced a new static method _get_state_key to derive in-memory state keys from serialized filenames.
- Updated state_dict and load_state_dict methods to use suffixless state keys instead of filenames.
- Adjusted related tests to reflect changes in state key handling, ensuring consistency in state management

* fix(processor): update loaded_config argument description in DataProcessorPipeline

- Clarified the documentation for the loaded_config parameter to indicate that it may be a non-dictionary value, enhancing understanding for future developers.
2026-06-08 11:27:24 +02:00
8 changed files with 583 additions and 76 deletions
+3 -3
View File
@@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot
topreward = ["lerobot[transformers-dep]"]
xvla = ["lerobot[transformers-dep]"]
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
# Features
@@ -231,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# Simulation
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
+7 -1
View File
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
for i in torch.randperm(len(self.indices), generator=self.generator):
yield self.indices[i]
else:
for i in self.indices:
+21 -1
View File
@@ -481,8 +481,10 @@ def reencode_video(
encoder_threads: int | None = None,
log_level: int | None = av.logging.WARNING,
overwrite: bool = False,
start_time_s: float | None = None,
end_time_s: float | None = None,
) -> None:
"""Re-encode a video file using the given encoder configuration.
"""Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``.
Args:
input_video_path: Existing video file to read.
@@ -491,10 +493,17 @@ def reencode_video(
encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`.
log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING.
overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning.
start_time_s: When set, trim the output to start at this timestamp (seconds).
end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive).
"""
camera_encoder = camera_encoder or camera_encoder_defaults()
if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0):
raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.")
if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s:
raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).")
output_video_path = Path(output_video_path)
if output_video_path.exists() and not overwrite:
@@ -526,6 +535,10 @@ def reencode_video(
width = int(in_stream.width)
height = int(in_stream.height)
# Seek to the keyframe at or before start_time_s to avoid reading from the start.
if start_time_s is not None:
src.seek(int(start_time_s * av.time_base), backward=True)
with av.open(
tmp_output_video_path,
mode="w",
@@ -539,7 +552,14 @@ def reencode_video(
out_stream.height = height
for frame in src.decode(in_stream):
frame_time_s = frame.time
if start_time_s is not None and frame_time_s < start_time_s:
continue
if end_time_s is not None and frame_time_s >= end_time_s:
break
frame = frame.reformat(width=width, height=height, format=pix_fmt)
if start_time_s is not None:
frame.pts = None # reset timestamps so the trimmed output starts at t=0
packet = out_stream.encode(frame)
if packet:
dst.mux(packet)
+279 -55
View File
@@ -32,7 +32,6 @@ from __future__ import annotations
import importlib
import json
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Sequence
@@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
_serialized_state_filenames: tuple[str | None, ...] | None = field(
default=None,
init=False,
repr=False,
)
def __call__(self, data: TInput) -> TOutput:
"""Processes input data through the full pipeline.
@@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
transition = processor_step(transition)
yield transition
def _save_pretrained(self, save_directory: Path, **kwargs):
"""Internal method to comply with `HubMixin`'s saving mechanism.
def _get_sanitized_name(self) -> str:
"""Return a filename-safe version of the pipeline name.
This method does the actual saving work and is called by HubMixin.save_pretrained.
Returns:
The lower-cased pipeline name with non-alphanumeric characters replaced by underscores.
"""
config_filename = kwargs.pop("config_filename", None)
return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
# Sanitize the pipeline name to create a valid filename prefix.
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower())
@staticmethod
def _get_state_filename(
*,
step_index: int,
registry_name: str | None,
sanitized_name: str,
) -> str:
"""Return the safetensors filename for one stateful processor step.
if config_filename is None:
config_filename = f"{sanitized_name}.json"
Args:
step_index: The index of the processor step in this pipeline.
registry_name: The registered processor step name, if available.
sanitized_name: The filename-safe pipeline name.
config: dict[str, Any] = {
Returns:
The state filename used by the existing disk serialization format.
"""
if registry_name:
return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
return f"{sanitized_name}_step_{step_index}.safetensors"
@staticmethod
def _get_state_key(state_filename: str) -> str:
"""Return the in-memory state key for a serialized state filename.
Args:
state_filename: The `.safetensors` filename from the serialized config.
Returns:
The state key used by the in-memory pipeline state dictionary.
"""
return state_filename.removesuffix(".safetensors")
@staticmethod
def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]:
"""Return serialized state filenames in step order.
Args:
loaded_config: A validated processor pipeline config.
Returns:
A tuple containing each step's serialized state filename, or None for stateless steps.
"""
return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"])
def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]:
"""Return expected state filenames in step order for `load_state_dict()`.
Returns:
The preserved serialized state filenames when available, otherwise filenames derived from
current non-empty step state.
"""
if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len(
self.steps
):
return self._serialized_state_filenames
sanitized_name = self._get_sanitized_name()
state_filenames: list[str | None] = []
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
state_filenames.append(None)
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filenames.append(
self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
)
return tuple(state_filenames)
def get_config(self) -> dict[str, Any]:
"""Return the JSON-serializable pipeline configuration.
Returns:
A dictionary with the same content that `save_pretrained()` writes as JSON.
"""
sanitized_name = self._get_sanitized_name()
pipeline_config: dict[str, Any] = {
"name": self.name,
"steps": [],
}
# Iterate through each step to build its configuration entry.
for step_index, processor_step in enumerate(self.steps):
registry_name = getattr(processor_step.__class__, "_registry_name", None)
step_entry: dict[str, Any] = {}
# Prefer registry name for portability, otherwise fall back to full class path.
if registry_name:
step_entry["registry_name"] = registry_name
else:
@@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}"
)
# Save step configuration if `get_config` is implemented.
if hasattr(processor_step, "get_config"):
step_entry["config"] = processor_step.get_config()
step_entry["config"] = processor_step.get_config()
# Save step state if `state_dict` is implemented and returns a non-empty dict.
if hasattr(processor_step, "state_dict"):
state = processor_step.state_dict()
if state:
# Clone tensors to avoid modifying the original state.
cloned_state = {key: tensor.clone() for key, tensor in state.items()}
step_state_dict = processor_step.state_dict()
if step_state_dict:
step_entry["state_file"] = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
# Create a unique filename for the state file.
if registry_name:
state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors"
else:
state_filename = f"{sanitized_name}_step_{step_index}.safetensors"
pipeline_config["steps"].append(step_entry)
save_file(cloned_state, os.path.join(str(save_directory), state_filename))
step_entry["state_file"] = state_filename
return pipeline_config
config["steps"].append(step_entry)
def state_dict(self) -> dict[str, dict[str, torch.Tensor]]:
"""Return pipeline state tensors grouped by state key.
# Write the main configuration JSON file.
with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer:
json.dump(config, file_pointer, indent=2)
Returns:
A dictionary mapping suffixless state keys to cloned step state dictionaries.
"""
sanitized_name = self._get_sanitized_name()
pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for step_index, processor_step in enumerate(self.steps):
step_state_dict = processor_step.state_dict()
if not step_state_dict:
continue
registry_name = getattr(processor_step.__class__, "_registry_name", None)
state_filename = self._get_state_filename(
step_index=step_index,
registry_name=registry_name,
sanitized_name=sanitized_name,
)
state_key = self._get_state_key(state_filename)
pipeline_state_dict[state_key] = {
tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items()
}
return pipeline_state_dict
def load_state_dict(
self,
state_dict: dict[str, dict[str, torch.Tensor]],
) -> None:
"""Load pipeline state tensors into the existing steps.
Args:
state_dict: A dictionary mapping suffixless state keys to step state dictionaries.
Raises:
KeyError: If loading finds missing expected state or unexpected extra state.
"""
expected_state_filenames = self._get_state_filenames_for_loading()
used_state_keys: set[str] = set()
for step_index, (processor_step, state_filename) in enumerate(
zip(self.steps, expected_state_filenames, strict=True)
):
if state_filename is None:
continue
state_key = self._get_state_key(state_filename)
if state_key not in state_dict:
raise KeyError(
f"Missing state key '{state_key}' for processor step {step_index}. "
f"Available state keys: {sorted(state_dict.keys())}"
)
processor_step.load_state_dict(state_dict[state_key])
used_state_keys.add(state_key)
unexpected_state_keys = set(state_dict) - used_state_keys
if unexpected_state_keys:
expected_state_key_set = {
self._get_state_key(state_filename)
for state_filename in expected_state_filenames
if state_filename is not None
}
raise KeyError(
f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. "
f"Expected state keys: {sorted(expected_state_key_set)}"
)
def _save_pretrained(self, save_directory: Path, **kwargs) -> None:
"""Internal method to comply with `HubMixin`'s saving mechanism.
This method does the actual saving work and is called by HubMixin.save_pretrained.
"""
config_filename = kwargs.pop("config_filename", None)
sanitized_name = self._get_sanitized_name()
if config_filename is None:
config_filename = f"{sanitized_name}.json"
pipeline_config = self.get_config()
pipeline_state_dict = self.state_dict()
for state_key, step_state_dict in pipeline_state_dict.items():
state_filename = f"{state_key}.safetensors"
save_file(step_state_dict, save_directory / state_filename)
with open(save_directory / config_filename, "w") as file_pointer:
json.dump(pipeline_config, file_pointer, indent=2)
def save_pretrained(
self,
@@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
cls._validate_overrides_used(validated_overrides, loaded_config)
# 5. Construct and return the final pipeline instance
return cls(
pipeline = cls(
steps=steps,
name=loaded_config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config)
return pipeline
@classmethod
def from_config(
cls,
config: dict[str, Any],
*,
state_dict: dict[str, dict[str, torch.Tensor]] | None = None,
overrides: dict[str, Any] | None = None,
to_transition: Callable[[TInput], EnvTransition] | None = None,
to_output: Callable[[EnvTransition], TOutput] | None = None,
) -> DataProcessorPipeline[TInput, TOutput]:
"""Build a pipeline from an in-memory config and optional state tensors.
Args:
config: A config dictionary with the same structure as the saved processor JSON.
state_dict: Optional in-memory pipeline state grouped by suffixless state key.
overrides: Optional constructor overrides keyed by registry name or class name.
to_transition: Optional converter from input data to `EnvTransition`.
to_output: Optional converter from `EnvTransition` to output data.
Returns:
A processor pipeline built from the config and optional state.
"""
cls._validate_loaded_config("<in-memory config>", config, "<in-memory config>")
steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {})
cls._validate_overrides_used(remaining_override_keys, config)
pipeline = cls(
steps=steps,
name=config.get("name", "DataProcessorPipeline"),
to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition),
to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch),
)
pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config)
if state_dict is not None:
pipeline.load_state_dict(state_dict)
return pipeline
@classmethod
def _load_config(
@@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
) from e
@classmethod
def _validate_loaded_config(
cls, model_id: str, loaded_config: dict[str, Any], config_filename: str
) -> None:
def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None:
"""Validate that a config was loaded and is a valid processor config.
This method validates processor config format with intelligent migration detection:
@@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Args:
model_id: The model identifier (used for migration detection)
loaded_config: The loaded config dictionary (guaranteed non-None)
loaded_config: The loaded config value to validate (may be non-dict)
config_filename: The config filename that was loaded (for error messages)
Raises:
@@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
model_id,
f"Config file '{config_filename}' is not a valid processor configuration",
)
loaded_config_description = (
list(loaded_config.keys())
if isinstance(loaded_config, dict)
else type(loaded_config).__name__
)
raise ValueError(
f"Config file '{config_filename}' is not a valid processor configuration. "
f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}"
f"Expected a config with 'steps' field, but got: {loaded_config_description}"
)
@classmethod
@@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
ImportError: If a step class cannot be imported or found in registry
ValueError: If a step cannot be instantiated with its configuration
"""
steps: list[ProcessorStep] = []
override_keys = set(overrides.keys())
steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides)
for step_entry in loaded_config["steps"]:
# 1. Get step class and key
step_class, step_key = cls._resolve_step_class(step_entry)
# 2. Instantiate step with overrides
step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides)
# 3. Load step state if available
for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True):
cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs)
# 4. Track used overrides
if step_key in override_keys:
override_keys.discard(step_key)
return steps, remaining_override_keys
steps.append(step_instance)
@classmethod
def _build_steps_from_config(
cls,
loaded_config: dict[str, Any],
overrides: dict[str, Any],
) -> tuple[list[ProcessorStep], set[str]]:
"""Build processor steps from config without loading tensor state.
return steps, override_keys
Args:
loaded_config: The loaded processor configuration.
overrides: User-provided constructor overrides keyed by step key.
Returns:
A tuple containing instantiated steps and override keys that did not match a step.
"""
processor_steps: list[ProcessorStep] = []
remaining_override_keys = set(overrides.keys())
for step_entry in loaded_config["steps"]:
step_class, step_key = cls._resolve_step_class(step_entry)
processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides)
if step_key in remaining_override_keys:
remaining_override_keys.discard(step_key)
processor_steps.append(processor_step)
return processor_steps, remaining_override_keys
@classmethod
def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]:
@@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
return True
@classmethod
def _is_processor_config(cls, config: dict) -> bool:
def _is_processor_config(cls, config: Any) -> bool:
"""Check if config follows DataProcessorPipeline format.
This method validates the processor configuration structure:
@@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin):
Returns:
True if config follows valid DataProcessorPipeline format, False otherwise
"""
if not isinstance(config, dict):
return False
# Must have a "steps" field with a list of step configurations
if not isinstance(config.get("steps"), list):
return False
+15 -5
View File
@@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
+24
View File
@@ -114,6 +114,30 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+220
View File
@@ -24,6 +24,7 @@ from typing import Any
import pytest
import torch
import torch.nn as nn
from safetensors.torch import load_file
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
@@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep):
return features
class MockLazyTensorStateStep(ProcessorStep):
"""Mock step whose tensor state is not present in constructor config."""
def __init__(
self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None
):
self.name = name
self.scale = scale
self.tensor_state: torch.Tensor | None = None
if initial_value is not None:
self.tensor_state = torch.tensor([initial_value], dtype=torch.float32)
def __call__(self, transition: EnvTransition) -> EnvTransition:
"""Return the transition unchanged."""
return transition
def get_config(self) -> dict[str, Any]:
"""Return constructor config while intentionally omitting tensor state."""
return {
"name": self.name,
"scale": self.scale,
}
def state_dict(self) -> dict[str, torch.Tensor]:
"""Return tensor state only after it has been initialized or loaded."""
if self.tensor_state is None:
return {}
return {"tensor_state": self.tensor_state}
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
"""Load tensor state."""
self.tensor_state = state["tensor_state"].clone()
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""Return features unchanged."""
return features
@ProcessorStepRegistry.register("registered_lazy_tensor_state_step")
class RegisteredLazyTensorStateStep(MockLazyTensorStateStep):
"""Registered lazy tensor state step for registry-based serialization tests."""
def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition)
@@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state():
assert torch.allclose(loaded_step.running_mean, step.running_mean)
def test_get_config_matches_saved_json():
"""Test that in-memory config matches the config written by save_pretrained."""
stateless_step = MockStep(name="stateless")
stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0)
pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline")
in_memory_config = pipeline.get_config()
assert pipeline.get_config() == in_memory_config
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
config_path = Path(tmp_dir) / "memory_pipeline.json"
with open(config_path) as file_pointer:
saved_config = json.load(file_pointer)
assert in_memory_config == saved_config
assert "state_file" not in in_memory_config["steps"][0]
assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors"
def test_state_dict_matches_saved_safetensors():
"""Test that in-memory state matches the safetensors written by save_pretrained."""
stateful_step = MockLazyTensorStateStep(initial_value=7.0)
pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline")
in_memory_state_dict = pipeline.state_dict()
state_filename = "stateful_pipeline_step_0.safetensors"
state_key = "stateful_pipeline_step_0"
assert set(in_memory_state_dict) == {state_key}
assert set(in_memory_state_dict[state_key]) == {"tensor_state"}
in_memory_state_dict[state_key]["tensor_state"].add_(1)
assert stateful_step.tensor_state is not None
assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0]))
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
saved_state_dict = load_file(Path(tmp_dir) / state_filename)
torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0]))
def test_save_pretrained_still_writes_expected_serialization_files():
"""Test that save_pretrained keeps the existing config and state filenames."""
stateful_step = MockLazyTensorStateStep(initial_value=3.0)
pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor")
with tempfile.TemporaryDirectory() as tmp_dir:
pipeline.save_pretrained(tmp_dir)
save_path = Path(tmp_dir)
assert (save_path / "policy_preprocessor.json").exists()
assert (save_path / "policy_preprocessor_step_0.safetensors").exists()
def test_from_config_round_trips_stateful_pipeline():
"""Test that from_config rebuilds a stateful pipeline from in-memory artifacts."""
stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0)
pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert len(loaded_pipeline) == 1
assert isinstance(loaded_step, MockLazyTensorStateStep)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0]))
def test_from_config_round_trips_registered_stateful_pipeline():
"""Test that from_config resolves registry steps and loads their named tensor state."""
stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0)
pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors"
state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step"
assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step"
assert config["steps"][0]["state_file"] == state_filename
assert set(pipeline_state_dict) == {state_key}
loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, RegisteredLazyTensorStateStep)
assert loaded_step.tensor_state is not None
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0]))
def test_from_config_preserves_state_metadata_for_empty_initial_state():
"""Test in-memory loading when rebuilt steps start without tensor state."""
stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0)
pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(config)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.state_dict() == {}
assert "state_file" not in loaded_pipeline.get_config()["steps"][0]
loaded_pipeline.load_state_dict(pipeline_state_dict)
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0]))
def test_from_config_applies_overrides_before_state_loading():
"""Test that constructor overrides and tensor state loading are separate operations."""
stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0)
pipeline = DataProcessorPipeline([stateful_step], name="Override Pipeline")
config = pipeline.get_config()
pipeline_state_dict = pipeline.state_dict()
loaded_pipeline = DataProcessorPipeline.from_config(
config,
state_dict=pipeline_state_dict,
overrides={"MockLazyTensorStateStep": {"scale": 5.0}},
)
loaded_step = loaded_pipeline.steps[0]
assert isinstance(loaded_step, MockLazyTensorStateStep)
assert loaded_step.scale == 5.0
torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0]))
def test_load_state_dict_raises_on_missing_expected_state():
"""Test loading raises when serialized config expects missing state."""
stateful_step = MockLazyTensorStateStep(initial_value=19.0)
pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline")
loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config())
with pytest.raises(KeyError, match="missing_pipeline_step_0"):
loaded_pipeline.load_state_dict({})
def test_load_state_dict_raises_on_unexpected_extra_state():
"""Test loading raises on unexpected top-level state keys."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline")
with pytest.raises(KeyError, match="extra"):
pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}})
def test_stateless_pipeline_in_memory_serialization_returns_empty_state():
"""Test stateless in-memory serialization and loading."""
pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline")
config = pipeline.get_config()
config_without_name = {"steps": config["steps"]}
assert pipeline.state_dict() == {}
assert all("state_file" not in step_entry for step_entry in config["steps"])
loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={})
assert loaded_pipeline.name == "DataProcessorPipeline"
assert loaded_pipeline.state_dict() == {}
@pytest.mark.parametrize("invalid_config", [None, [], "not config"])
def test_from_config_rejects_non_dict_config(invalid_config):
"""Test from_config reports invalid top-level config values cleanly."""
with pytest.raises(ValueError, match="not a valid processor configuration"):
DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type]
class MockModuleStep(ProcessorStep, nn.Module):
"""Mock step that inherits from nn.Module to test state_dict handling of module parameters."""
Generated
+14 -11
View File
@@ -1764,7 +1764,7 @@ wheels = [
[[package]]
name = "gym-aloha"
version = "0.1.3"
version = "0.1.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "dm-control" },
@@ -1772,14 +1772,14 @@ dependencies = [
{ name = "imageio", extra = ["ffmpeg"] },
{ name = "mujoco" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" }
sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" },
{ url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" },
]
[[package]]
name = "gym-hil"
version = "0.1.13"
version = "0.1.14"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "gymnasium" },
@@ -1789,9 +1789,9 @@ dependencies = [
{ name = "pygame" },
{ name = "pynput" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" }
sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" },
{ url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" },
]
[[package]]
@@ -1881,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9
[[package]]
name = "hf-libero"
version = "0.1.3"
version = "0.1.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "bddl", marker = "sys_platform == 'linux'" },
@@ -1902,7 +1902,10 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'linux'" },
{ name = "wandb", marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" }
sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" },
]
[[package]]
name = "hf-xet"
@@ -3090,12 +3093,12 @@ requires-dist = [
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" },
{ name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" },
{ name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" },
{ name = "gymnasium", specifier = ">=1.1.1,<2.0.0" },
{ name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" },
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" },
{ name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" },
{ name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" },
{ name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" },
{ name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" },