From 49755a3d9e7d43ae93092de8324e75348955afab Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 8 Jun 2026 11:27:24 +0200 Subject: [PATCH 01/27] feat(processor): Add in-memory processor pipeline serialization (#3732) * feat(processor): add in-memory pipeline serialization Expose processor pipeline config and tensor state without requiring temporary files, so processors can be transported, compared, or hashed directly in memory. * feat(processor): enhance DataProcessorPipeline with registry support - Added a new RegisteredLazyTensorStateStep for registry-based serialization tests. - Improved state filename handling in _get_state_filename method. - Refactored validation logic in _validate_loaded_config to simplify parameter types. - Updated tests to verify registry step functionality and ensure correct state loading. * refactor(processor): update state handling in DataProcessorPipeline - Introduced a new static method _get_state_key to derive in-memory state keys from serialized filenames. - Updated state_dict and load_state_dict methods to use suffixless state keys instead of filenames. - Adjusted related tests to reflect changes in state key handling, ensuring consistency in state management * fix(processor): update loaded_config argument description in DataProcessorPipeline - Clarified the documentation for the loaded_config parameter to indicate that it may be a non-dictionary value, enhancing understanding for future developers. 
--- src/lerobot/processor/pipeline.py | 334 +++++++++++++++++++++++++----- tests/processor/test_pipeline.py | 220 ++++++++++++++++++++ 2 files changed, 499 insertions(+), 55 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 2b949d5cb..b9b9c6c43 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -32,7 +32,6 @@ from __future__ import annotations import importlib import json -import os import re from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Sequence @@ -281,6 +280,11 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + _serialized_state_filenames: tuple[str | None, ...] | None = field( + default=None, + init=False, + repr=False, + ) def __call__(self, data: TInput) -> TOutput: """Processes input data through the full pipeline. @@ -338,30 +342,108 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): transition = processor_step(transition) yield transition - def _save_pretrained(self, save_directory: Path, **kwargs): - """Internal method to comply with `HubMixin`'s saving mechanism. + def _get_sanitized_name(self) -> str: + """Return a filename-safe version of the pipeline name. - This method does the actual saving work and is called by HubMixin.save_pretrained. + Returns: + The lower-cased pipeline name with non-alphanumeric characters replaced by underscores. """ - config_filename = kwargs.pop("config_filename", None) + return re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) - # Sanitize the pipeline name to create a valid filename prefix. 
- sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + @staticmethod + def _get_state_filename( + *, + step_index: int, + registry_name: str | None, + sanitized_name: str, + ) -> str: + """Return the safetensors filename for one stateful processor step. - if config_filename is None: - config_filename = f"{sanitized_name}.json" + Args: + step_index: The index of the processor step in this pipeline. + registry_name: The registered processor step name, if available. + sanitized_name: The filename-safe pipeline name. - config: dict[str, Any] = { + Returns: + The state filename used by the existing disk serialization format. + """ + if registry_name: + return f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" + + return f"{sanitized_name}_step_{step_index}.safetensors" + + @staticmethod + def _get_state_key(state_filename: str) -> str: + """Return the in-memory state key for a serialized state filename. + + Args: + state_filename: The `.safetensors` filename from the serialized config. + + Returns: + The state key used by the in-memory pipeline state dictionary. + """ + return state_filename.removesuffix(".safetensors") + + @staticmethod + def _get_state_filenames_from_config(loaded_config: dict[str, Any]) -> tuple[str | None, ...]: + """Return serialized state filenames in step order. + + Args: + loaded_config: A validated processor pipeline config. + + Returns: + A tuple containing each step's serialized state filename, or None for stateless steps. + """ + return tuple(step_entry.get("state_file") for step_entry in loaded_config["steps"]) + + def _get_state_filenames_for_loading(self) -> tuple[str | None, ...]: + """Return expected state filenames in step order for `load_state_dict()`. + + Returns: + The preserved serialized state filenames when available, otherwise filenames derived from + current non-empty step state. 
+ """ + if self._serialized_state_filenames is not None and len(self._serialized_state_filenames) == len( + self.steps + ): + return self._serialized_state_filenames + + sanitized_name = self._get_sanitized_name() + state_filenames: list[str | None] = [] + + for step_index, processor_step in enumerate(self.steps): + step_state_dict = processor_step.state_dict() + if not step_state_dict: + state_filenames.append(None) + continue + + registry_name = getattr(processor_step.__class__, "_registry_name", None) + state_filenames.append( + self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) + ) + + return tuple(state_filenames) + + def get_config(self) -> dict[str, Any]: + """Return the JSON-serializable pipeline configuration. + + Returns: + A dictionary with the same content that `save_pretrained()` writes as JSON. + """ + sanitized_name = self._get_sanitized_name() + pipeline_config: dict[str, Any] = { "name": self.name, "steps": [], } - # Iterate through each step to build its configuration entry. for step_index, processor_step in enumerate(self.steps): registry_name = getattr(processor_step.__class__, "_registry_name", None) - step_entry: dict[str, Any] = {} - # Prefer registry name for portability, otherwise fall back to full class path. + if registry_name: step_entry["registry_name"] = registry_name else: @@ -369,31 +451,110 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" ) - # Save step configuration if `get_config` is implemented. - if hasattr(processor_step, "get_config"): - step_entry["config"] = processor_step.get_config() + step_entry["config"] = processor_step.get_config() - # Save step state if `state_dict` is implemented and returns a non-empty dict. - if hasattr(processor_step, "state_dict"): - state = processor_step.state_dict() - if state: - # Clone tensors to avoid modifying the original state. 
- cloned_state = {key: tensor.clone() for key, tensor in state.items()} + step_state_dict = processor_step.state_dict() + if step_state_dict: + step_entry["state_file"] = self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) - # Create a unique filename for the state file. - if registry_name: - state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" - else: - state_filename = f"{sanitized_name}_step_{step_index}.safetensors" + pipeline_config["steps"].append(step_entry) - save_file(cloned_state, os.path.join(str(save_directory), state_filename)) - step_entry["state_file"] = state_filename + return pipeline_config - config["steps"].append(step_entry) + def state_dict(self) -> dict[str, dict[str, torch.Tensor]]: + """Return pipeline state tensors grouped by state key. - # Write the main configuration JSON file. - with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: - json.dump(config, file_pointer, indent=2) + Returns: + A dictionary mapping suffixless state keys to cloned step state dictionaries. + """ + sanitized_name = self._get_sanitized_name() + pipeline_state_dict: dict[str, dict[str, torch.Tensor]] = {} + + for step_index, processor_step in enumerate(self.steps): + step_state_dict = processor_step.state_dict() + if not step_state_dict: + continue + + registry_name = getattr(processor_step.__class__, "_registry_name", None) + state_filename = self._get_state_filename( + step_index=step_index, + registry_name=registry_name, + sanitized_name=sanitized_name, + ) + state_key = self._get_state_key(state_filename) + pipeline_state_dict[state_key] = { + tensor_name: tensor.clone() for tensor_name, tensor in step_state_dict.items() + } + + return pipeline_state_dict + + def load_state_dict( + self, + state_dict: dict[str, dict[str, torch.Tensor]], + ) -> None: + """Load pipeline state tensors into the existing steps. 
+ + Args: + state_dict: A dictionary mapping suffixless state keys to step state dictionaries. + + Raises: + KeyError: If loading finds missing expected state or unexpected extra state. + """ + expected_state_filenames = self._get_state_filenames_for_loading() + used_state_keys: set[str] = set() + + for step_index, (processor_step, state_filename) in enumerate( + zip(self.steps, expected_state_filenames, strict=True) + ): + if state_filename is None: + continue + + state_key = self._get_state_key(state_filename) + if state_key not in state_dict: + raise KeyError( + f"Missing state key '{state_key}' for processor step {step_index}. " + f"Available state keys: {sorted(state_dict.keys())}" + ) + + processor_step.load_state_dict(state_dict[state_key]) + used_state_keys.add(state_key) + + unexpected_state_keys = set(state_dict) - used_state_keys + if unexpected_state_keys: + expected_state_key_set = { + self._get_state_key(state_filename) + for state_filename in expected_state_filenames + if state_filename is not None + } + raise KeyError( + f"Unexpected processor state keys: {sorted(unexpected_state_keys)}. " + f"Expected state keys: {sorted(expected_state_key_set)}" + ) + + def _save_pretrained(self, save_directory: Path, **kwargs) -> None: + """Internal method to comply with `HubMixin`'s saving mechanism. + + This method does the actual saving work and is called by HubMixin.save_pretrained. 
+ """ + config_filename = kwargs.pop("config_filename", None) + sanitized_name = self._get_sanitized_name() + + if config_filename is None: + config_filename = f"{sanitized_name}.json" + + pipeline_config = self.get_config() + pipeline_state_dict = self.state_dict() + + for state_key, step_state_dict in pipeline_state_dict.items(): + state_filename = f"{state_key}.safetensors" + save_file(step_state_dict, save_directory / state_filename) + + with open(save_directory / config_filename, "w") as file_pointer: + json.dump(pipeline_config, file_pointer, indent=2) def save_pretrained( self, @@ -577,12 +738,54 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): cls._validate_overrides_used(validated_overrides, loaded_config) # 5. Construct and return the final pipeline instance - return cls( + pipeline = cls( steps=steps, name=loaded_config.get("name", "DataProcessorPipeline"), to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), ) + pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(loaded_config) + return pipeline + + @classmethod + def from_config( + cls, + config: dict[str, Any], + *, + state_dict: dict[str, dict[str, torch.Tensor]] | None = None, + overrides: dict[str, Any] | None = None, + to_transition: Callable[[TInput], EnvTransition] | None = None, + to_output: Callable[[EnvTransition], TOutput] | None = None, + ) -> DataProcessorPipeline[TInput, TOutput]: + """Build a pipeline from an in-memory config and optional state tensors. + + Args: + config: A config dictionary with the same structure as the saved processor JSON. + state_dict: Optional in-memory pipeline state grouped by suffixless state key. + overrides: Optional constructor overrides keyed by registry name or class name. + to_transition: Optional converter from input data to `EnvTransition`. 
+ to_output: Optional converter from `EnvTransition` to output data. + + Returns: + A processor pipeline built from the config and optional state. + """ + cls._validate_loaded_config("", config, "") + + steps, remaining_override_keys = cls._build_steps_from_config(config, overrides or {}) + cls._validate_overrides_used(remaining_override_keys, config) + + pipeline = cls( + steps=steps, + name=config.get("name", "DataProcessorPipeline"), + to_transition=to_transition or cast(Callable[[TInput], EnvTransition], batch_to_transition), + to_output=to_output or cast(Callable[[EnvTransition], TOutput], transition_to_batch), + ) + pipeline._serialized_state_filenames = cls._get_state_filenames_from_config(config) + + if state_dict is not None: + pipeline.load_state_dict(state_dict) + + return pipeline @classmethod def _load_config( @@ -666,9 +869,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): ) from e @classmethod - def _validate_loaded_config( - cls, model_id: str, loaded_config: dict[str, Any], config_filename: str - ) -> None: + def _validate_loaded_config(cls, model_id: str, loaded_config: Any, config_filename: str) -> None: """Validate that a config was loaded and is a valid processor config. 
This method validates processor config format with intelligent migration detection: @@ -688,7 +889,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): Args: model_id: The model identifier (used for migration detection) - loaded_config: The loaded config dictionary (guaranteed non-None) + loaded_config: The loaded config value to validate (may be non-dict) config_filename: The config filename that was loaded (for error messages) Raises: @@ -702,9 +903,14 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): model_id, f"Config file '{config_filename}' is not a valid processor configuration", ) + loaded_config_description = ( + list(loaded_config.keys()) + if isinstance(loaded_config, dict) + else type(loaded_config).__name__ + ) raise ValueError( f"Config file '{config_filename}' is not a valid processor configuration. " - f"Expected a config with 'steps' field, but got: {list(loaded_config.keys())}" + f"Expected a config with 'steps' field, but got: {loaded_config_description}" ) @classmethod @@ -766,26 +972,41 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): ImportError: If a step class cannot be imported or found in registry ValueError: If a step cannot be instantiated with its configuration """ - steps: list[ProcessorStep] = [] - override_keys = set(overrides.keys()) + steps, remaining_override_keys = cls._build_steps_from_config(loaded_config, overrides) - for step_entry in loaded_config["steps"]: - # 1. Get step class and key - step_class, step_key = cls._resolve_step_class(step_entry) - - # 2. Instantiate step with overrides - step_instance = cls._instantiate_step(step_entry, step_class, step_key, overrides) - - # 3. Load step state if available + for step_instance, step_entry in zip(steps, loaded_config["steps"], strict=True): cls._load_step_state(step_instance, step_entry, model_id, base_path, hub_download_kwargs) - # 4. 
Track used overrides - if step_key in override_keys: - override_keys.discard(step_key) + return steps, remaining_override_keys - steps.append(step_instance) + @classmethod + def _build_steps_from_config( + cls, + loaded_config: dict[str, Any], + overrides: dict[str, Any], + ) -> tuple[list[ProcessorStep], set[str]]: + """Build processor steps from config without loading tensor state. - return steps, override_keys + Args: + loaded_config: The loaded processor configuration. + overrides: User-provided constructor overrides keyed by step key. + + Returns: + A tuple containing instantiated steps and override keys that did not match a step. + """ + processor_steps: list[ProcessorStep] = [] + remaining_override_keys = set(overrides.keys()) + + for step_entry in loaded_config["steps"]: + step_class, step_key = cls._resolve_step_class(step_entry) + processor_step = cls._instantiate_step(step_entry, step_class, step_key, overrides) + + if step_key in remaining_override_keys: + remaining_override_keys.discard(step_key) + + processor_steps.append(processor_step) + + return processor_steps, remaining_override_keys @classmethod def _resolve_step_class(cls, step_entry: dict[str, Any]) -> tuple[type[ProcessorStep], str]: @@ -1096,7 +1317,7 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): return True @classmethod - def _is_processor_config(cls, config: dict) -> bool: + def _is_processor_config(cls, config: Any) -> bool: """Check if config follows DataProcessorPipeline format. 
This method validates the processor configuration structure: @@ -1147,6 +1368,9 @@ class DataProcessorPipeline[TInput, TOutput](HubMixin): Returns: True if config follows valid DataProcessorPipeline format, False otherwise """ + if not isinstance(config, dict): + return False + # Must have a "steps" field with a list of step configurations if not isinstance(config.get("steps"), list): return False diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 2c41de22c..57e948279 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -24,6 +24,7 @@ from typing import Any import pytest import torch import torch.nn as nn +from safetensors.torch import load_file pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") @@ -174,6 +175,53 @@ class MockStepWithTensorState(ProcessorStep): return features +class MockLazyTensorStateStep(ProcessorStep): + """Mock step whose tensor state is not present in constructor config.""" + + def __init__( + self, name: str = "lazy_tensor_step", scale: float = 1.0, initial_value: float | None = None + ): + self.name = name + self.scale = scale + self.tensor_state: torch.Tensor | None = None + + if initial_value is not None: + self.tensor_state = torch.tensor([initial_value], dtype=torch.float32) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Return the transition unchanged.""" + return transition + + def get_config(self) -> dict[str, Any]: + """Return constructor config while intentionally omitting tensor state.""" + return { + "name": self.name, + "scale": self.scale, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return tensor state only after it has been initialized or loaded.""" + if self.tensor_state is None: + return {} + + return {"tensor_state": self.tensor_state} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load tensor state.""" + self.tensor_state = 
state["tensor_state"].clone() + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Return features unchanged.""" + return features + + +@ProcessorStepRegistry.register("registered_lazy_tensor_state_step") +class RegisteredLazyTensorStateStep(MockLazyTensorStateStep): + """Registered lazy tensor state step for registry-based serialization tests.""" + + def test_empty_pipeline(): """Test pipeline with no steps.""" pipeline = DataProcessorPipeline([], to_transition=identity_transition, to_output=identity_transition) @@ -620,6 +668,178 @@ def test_mixed_json_and_tensor_state(): assert torch.allclose(loaded_step.running_mean, step.running_mean) +def test_get_config_matches_saved_json(): + """Test that in-memory config matches the config written by save_pretrained.""" + stateless_step = MockStep(name="stateless") + stateful_step = MockLazyTensorStateStep(name="stateful", initial_value=4.0) + pipeline = DataProcessorPipeline([stateless_step, stateful_step], name="Memory Pipeline") + + in_memory_config = pipeline.get_config() + + assert pipeline.get_config() == in_memory_config + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + config_path = Path(tmp_dir) / "memory_pipeline.json" + with open(config_path) as file_pointer: + saved_config = json.load(file_pointer) + + assert in_memory_config == saved_config + assert "state_file" not in in_memory_config["steps"][0] + assert in_memory_config["steps"][1]["state_file"] == "memory_pipeline_step_1.safetensors" + + +def test_state_dict_matches_saved_safetensors(): + """Test that in-memory state matches the safetensors written by save_pretrained.""" + stateful_step = MockLazyTensorStateStep(initial_value=7.0) + pipeline = DataProcessorPipeline([stateful_step], name="Stateful Pipeline") + + in_memory_state_dict = pipeline.state_dict() + state_filename = 
"stateful_pipeline_step_0.safetensors" + state_key = "stateful_pipeline_step_0" + + assert set(in_memory_state_dict) == {state_key} + assert set(in_memory_state_dict[state_key]) == {"tensor_state"} + + in_memory_state_dict[state_key]["tensor_state"].add_(1) + assert stateful_step.tensor_state is not None + assert torch.equal(stateful_step.tensor_state, torch.tensor([7.0])) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + saved_state_dict = load_file(Path(tmp_dir) / state_filename) + + torch.testing.assert_close(saved_state_dict["tensor_state"], torch.tensor([7.0])) + + +def test_save_pretrained_still_writes_expected_serialization_files(): + """Test that save_pretrained keeps the existing config and state filenames.""" + stateful_step = MockLazyTensorStateStep(initial_value=3.0) + pipeline = DataProcessorPipeline([stateful_step], name="Policy Preprocessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + save_path = Path(tmp_dir) + assert (save_path / "policy_preprocessor.json").exists() + assert (save_path / "policy_preprocessor_step_0.safetensors").exists() + + +def test_from_config_round_trips_stateful_pipeline(): + """Test that from_config rebuilds a stateful pipeline from in-memory artifacts.""" + stateful_step = MockLazyTensorStateStep(name="roundtrip", initial_value=11.0) + pipeline = DataProcessorPipeline([stateful_step], name="Roundtrip Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict) + loaded_step = loaded_pipeline.steps[0] + + assert len(loaded_pipeline) == 1 + assert isinstance(loaded_step, MockLazyTensorStateStep) + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([11.0])) + + +def test_from_config_round_trips_registered_stateful_pipeline(): + """Test that from_config resolves registry steps and loads their named tensor 
state.""" + stateful_step = RegisteredLazyTensorStateStep(name="registered", initial_value=29.0) + pipeline = DataProcessorPipeline([stateful_step], name="Registry Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + state_filename = "registry_pipeline_step_0_registered_lazy_tensor_state_step.safetensors" + state_key = "registry_pipeline_step_0_registered_lazy_tensor_state_step" + + assert config["steps"][0]["registry_name"] == "registered_lazy_tensor_state_step" + assert config["steps"][0]["state_file"] == state_filename + assert set(pipeline_state_dict) == {state_key} + + loaded_pipeline = DataProcessorPipeline.from_config(config, state_dict=pipeline_state_dict) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, RegisteredLazyTensorStateStep) + assert loaded_step.tensor_state is not None + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([29.0])) + + +def test_from_config_preserves_state_metadata_for_empty_initial_state(): + """Test in-memory loading when rebuilt steps start without tensor state.""" + stateful_step = MockLazyTensorStateStep(name="lazy", initial_value=13.0) + pipeline = DataProcessorPipeline([stateful_step], name="Lazy Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config(config) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, MockLazyTensorStateStep) + assert loaded_step.state_dict() == {} + assert "state_file" not in loaded_pipeline.get_config()["steps"][0] + + loaded_pipeline.load_state_dict(pipeline_state_dict) + + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([13.0])) + + +def test_from_config_applies_overrides_before_state_loading(): + """Test that constructor overrides and tensor state loading are separate operations.""" + stateful_step = MockLazyTensorStateStep(name="override", scale=1.0, initial_value=17.0) + pipeline = 
DataProcessorPipeline([stateful_step], name="Override Pipeline") + config = pipeline.get_config() + pipeline_state_dict = pipeline.state_dict() + + loaded_pipeline = DataProcessorPipeline.from_config( + config, + state_dict=pipeline_state_dict, + overrides={"MockLazyTensorStateStep": {"scale": 5.0}}, + ) + loaded_step = loaded_pipeline.steps[0] + + assert isinstance(loaded_step, MockLazyTensorStateStep) + assert loaded_step.scale == 5.0 + torch.testing.assert_close(loaded_step.tensor_state, torch.tensor([17.0])) + + +def test_load_state_dict_raises_on_missing_expected_state(): + """Test loading raises when serialized config expects missing state.""" + stateful_step = MockLazyTensorStateStep(initial_value=19.0) + pipeline = DataProcessorPipeline([stateful_step], name="Missing Pipeline") + loaded_pipeline = DataProcessorPipeline.from_config(pipeline.get_config()) + + with pytest.raises(KeyError, match="missing_pipeline_step_0"): + loaded_pipeline.load_state_dict({}) + + +def test_load_state_dict_raises_on_unexpected_extra_state(): + """Test loading raises on unexpected top-level state keys.""" + pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Unexpected Pipeline") + + with pytest.raises(KeyError, match="extra"): + pipeline.load_state_dict({"extra": {"tensor_state": torch.tensor([1.0])}}) + + +def test_stateless_pipeline_in_memory_serialization_returns_empty_state(): + """Test stateless in-memory serialization and loading.""" + pipeline = DataProcessorPipeline([MockStep(name="stateless")], name="Stateless Pipeline") + config = pipeline.get_config() + config_without_name = {"steps": config["steps"]} + + assert pipeline.state_dict() == {} + assert all("state_file" not in step_entry for step_entry in config["steps"]) + + loaded_pipeline = DataProcessorPipeline.from_config(config_without_name, state_dict={}) + + assert loaded_pipeline.name == "DataProcessorPipeline" + assert loaded_pipeline.state_dict() == {} + + 
+@pytest.mark.parametrize("invalid_config", [None, [], "not config"]) +def test_from_config_rejects_non_dict_config(invalid_config): + """Test from_config reports invalid top-level config values cleanly.""" + with pytest.raises(ValueError, match="not a valid processor configuration"): + DataProcessorPipeline.from_config(invalid_config) # type: ignore[arg-type] + + class MockModuleStep(ProcessorStep, nn.Module): """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" From bd22407d9390f6973868fed9a883b4f295db105e Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 9 Jun 2026 23:31:43 +0200 Subject: [PATCH 02/27] fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751) * fix(pyproject): adding ceiling bound on mujoco (<3.9.0) * chore(uv.lock): updating uv.lock * fix(linux): adding missing linux dependencies * chore(uv.lock): updating uv.lock --- pyproject.toml | 9 +++++---- uv.lock | 15 ++++++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b4c22f12..9690a0d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot topreward = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]", "lerobot[mujoco-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] # Features @@ -231,10 +231,11 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation # NOTE: Explicitly listing scipy helps flatten the dependecy tree. 
-aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] +mujoco-dep = ["mujoco<3.9.0"] # TODO: Fix issues to remove temporary upper bound +aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] -metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"] +libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] +metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] # NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution # is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI # release), so any `vlabench>=X` pip spec is unresolvable. 
Install it diff --git a/uv.lock b/uv.lock index 6acacab56..14b7f6f79 100644 --- a/uv.lock +++ b/uv.lock @@ -2712,6 +2712,7 @@ all = [ { name = "mock-serial", marker = "sys_platform != 'win32'" }, { name = "motorbridge" }, { name = "motorbridge-smart-servo" }, + { name = "mujoco" }, { name = "mypy" }, { name = "num2words" }, { name = "pandas" }, @@ -2749,6 +2750,7 @@ aloha = [ { name = "datasets" }, { name = "gym-aloha" }, { name = "jsonlines" }, + { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2864,6 +2866,7 @@ hilserl = [ { name = "grpcio" }, { name = "gym-hil" }, { name = "jsonlines" }, + { name = "mujoco" }, { name = "pandas" }, { name = "placo" }, { name = "protobuf" }, @@ -2895,6 +2898,7 @@ libero = [ { name = "datasets" }, { name = "hf-libero", marker = "sys_platform == 'linux'" }, { name = "jsonlines" }, + { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2910,6 +2914,7 @@ metaworld = [ { name = "datasets" }, { name = "jsonlines" }, { name = "metaworld" }, + { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2926,6 +2931,9 @@ motorbridge-dep = [ motorbridge-smart-servo-dep = [ { name = "motorbridge-smart-servo" }, ] +mujoco-dep = [ + { name = "mujoco" }, +] multi-task-dit = [ { name = "diffusers" }, { name = "transformers" }, @@ -3150,6 +3158,10 @@ requires-dist = [ { name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" }, + { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'aloha'" }, + { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'hilserl'" }, + { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'libero'" }, + { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'metaworld'" }, { name = "lerobot", 
extras = ["multi-task-dit"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["notebook"], marker = "extra == 'dev'" }, { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" }, @@ -3223,6 +3235,7 @@ requires-dist = [ { name = "mock-serial", marker = "sys_platform != 'win32' and extra == 'test'", specifier = ">=0.0.1,<0.1.0" }, { name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" }, { name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" }, + { name = "mujoco", marker = "extra == 'mujoco-dep'", specifier = "<3.9.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" }, { name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" }, { name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" }, @@ -3276,7 +3289,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", 
"evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "mujoco-dep", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt" From 507083249f9715b2ad3893fbe9757903055aa644 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 10 Jun 2026 10:38:42 +0200 Subject: [PATCH 03/27] Revert "fix(pyproject): adding ceiling bound on mujoco (<3.9.0) (#3751)" (#3754) This reverts commit bd22407d9390f6973868fed9a883b4f295db105e. 
--- pyproject.toml | 9 ++++----- uv.lock | 15 +-------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9690a0d2c..2b4c22f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot topreward = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]", "lerobot[mujoco-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] # Features @@ -231,11 +231,10 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation # NOTE: Explicitly listing scipy helps flatten the dependecy tree. -mujoco-dep = ["mujoco<3.9.0"] # TODO: Fix issues to remove temporary upper bound -aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] +aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] -metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]", "lerobot[mujoco-dep]"] +libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] +metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"] # NOTE: vlabench is NOT exposed as a `lerobot` extra. 
Its only distribution # is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI # release), so any `vlabench>=X` pip spec is unresolvable. Install it diff --git a/uv.lock b/uv.lock index 14b7f6f79..6acacab56 100644 --- a/uv.lock +++ b/uv.lock @@ -2712,7 +2712,6 @@ all = [ { name = "mock-serial", marker = "sys_platform != 'win32'" }, { name = "motorbridge" }, { name = "motorbridge-smart-servo" }, - { name = "mujoco" }, { name = "mypy" }, { name = "num2words" }, { name = "pandas" }, @@ -2750,7 +2749,6 @@ aloha = [ { name = "datasets" }, { name = "gym-aloha" }, { name = "jsonlines" }, - { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2866,7 +2864,6 @@ hilserl = [ { name = "grpcio" }, { name = "gym-hil" }, { name = "jsonlines" }, - { name = "mujoco" }, { name = "pandas" }, { name = "placo" }, { name = "protobuf" }, @@ -2898,7 +2895,6 @@ libero = [ { name = "datasets" }, { name = "hf-libero", marker = "sys_platform == 'linux'" }, { name = "jsonlines" }, - { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2914,7 +2910,6 @@ metaworld = [ { name = "datasets" }, { name = "jsonlines" }, { name = "metaworld" }, - { name = "mujoco" }, { name = "pandas" }, { name = "pyarrow" }, { name = "scipy" }, @@ -2931,9 +2926,6 @@ motorbridge-dep = [ motorbridge-smart-servo-dep = [ { name = "motorbridge-smart-servo" }, ] -mujoco-dep = [ - { name = "mujoco" }, -] multi-task-dit = [ { name = "diffusers" }, { name = "transformers" }, @@ -3158,10 +3150,6 @@ requires-dist = [ { name = "lerobot", extras = ["molmoact2"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["motorbridge-dep"], marker = "extra == 'rebot'" }, { name = "lerobot", extras = ["motorbridge-smart-servo-dep"], marker = "extra == 'rebot'" }, - { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'aloha'" }, - { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'hilserl'" }, - { name = "lerobot", extras = 
["mujoco-dep"], marker = "extra == 'libero'" }, - { name = "lerobot", extras = ["mujoco-dep"], marker = "extra == 'metaworld'" }, { name = "lerobot", extras = ["multi-task-dit"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["notebook"], marker = "extra == 'dev'" }, { name = "lerobot", extras = ["openarms"], marker = "extra == 'all'" }, @@ -3235,7 +3223,6 @@ requires-dist = [ { name = "mock-serial", marker = "sys_platform != 'win32' and extra == 'test'", specifier = ">=0.0.1,<0.1.0" }, { name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" }, { name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" }, - { name = "mujoco", marker = "extra == 'mujoco-dep'", specifier = "<3.9.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" }, { name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" }, { name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" }, @@ -3289,7 +3276,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", 
"video-benchmark", "mujoco-dep", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt" From 79c68214070e392f04800ed092dd89e0b761f44e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 10 Jun 2026 12:58:55 +0200 Subject: [PATCH 04/27] chore(dependecies): update mujoco transitives (#3756) --- pyproject.toml | 6 +++--- uv.lock | 25 ++++++++++++++----------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b4c22f12..f72cfa6dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,7 +216,7 @@ robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot topreward = ["lerobot[transformers-dep]"] xvla = ["lerobot[transformers-dep]"] eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"] -hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] +hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"] # Features @@ -231,9 +231,9 @@ 
video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation # NOTE: Explicitly listing scipy helps flatten the dependecy tree. -aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"] +aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"] pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] +libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"] metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"] # NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution # is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI diff --git a/uv.lock b/uv.lock index 6acacab56..3a7129dac 100644 --- a/uv.lock +++ b/uv.lock @@ -1764,7 +1764,7 @@ wheels = [ [[package]] name = "gym-aloha" -version = "0.1.3" +version = "0.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dm-control" }, @@ -1772,14 +1772,14 @@ dependencies = [ { name = "imageio", extra = ["ffmpeg"] }, { name = "mujoco" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/5e/4bb7204730501c2f645e0532a2df4339206948b2882f77cbf0eaf75bc5fe/gym_aloha-0.1.3.tar.gz", hash = "sha256:b794b246a2e6da6ce5f75e152f553fbd4412704bc217fe6311d0ede3bb72a75e", size = 443468, upload-time = "2025-10-09T14:02:35.024Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/c5/a5b8bdbddfcadec0b52b50e6d1a70325e09e6b594e5f55929d67d9122e2c/gym_aloha-0.1.4.tar.gz", hash = "sha256:0dc4e645045aeb3e74e3c320872d28df6dc93a8751d6ab2f266a2ca11323131f", size = 443466, upload-time = "2026-06-10T09:13:25.525Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/57/6c/10da397177c48ce360efa66ec21b10b10ef5fa2766256fcd8d7d9b5fa6fc/gym_aloha-0.1.3-py3-none-any.whl", hash = "sha256:a94e5747e71307897ded7ae17ed97fab05e814dcb714a16d320f110444f9d0c3", size = 447908, upload-time = "2025-10-09T14:02:33.253Z" }, + { url = "https://files.pythonhosted.org/packages/35/e3/3afd0e517a503aabe255bf65f5136490acb79c43189e8d56a3aa63081a10/gym_aloha-0.1.4-py3-none-any.whl", hash = "sha256:d9044290fbccddf0be4246b5287cf0eb6b9ddee545a3d222ce8d78c93ce7125e", size = 447908, upload-time = "2026-06-10T09:13:23.868Z" }, ] [[package]] name = "gym-hil" -version = "0.1.13" +version = "0.1.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gymnasium" }, @@ -1789,9 +1789,9 @@ dependencies = [ { name = "pygame" }, { name = "pynput" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f3/41/e89c87b3c66fb2f8ab5818bff4aa552977911eabaee7c12a8a336dcc406f/gym_hil-0.1.13.tar.gz", hash = "sha256:b9eab7a0acc811f181254e3ad72865830fdbb292c236895f374135d3d62f1b27", size = 5668001, upload-time = "2025-10-21T09:57:24.01Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/64/b5cfe59d6a69d20497218f01ad2bdaa2a5a72b850bdb1a445d804ecc9948/gym_hil-0.1.14.tar.gz", hash = "sha256:aeee688dcb3ec72e7bcbe604df4a3f990cce49c8a2da469dd67c3a4eeb4c6bbb", size = 5667991, upload-time = "2026-06-10T09:16:38.98Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/8d/9e3ab53f9aac7bd542f339efd0a9283fa76e034474987e0705379274dfcf/gym_hil-0.1.13-py3-none-any.whl", hash = "sha256:b6444fc43ce1a68ce403df14f99100d9c903ae05d822959e9cd0b76a50b93320", size = 5750805, upload-time = "2025-10-21T09:57:22.068Z" }, + { url = "https://files.pythonhosted.org/packages/72/97/a7a9c3886306a89046ba5c989bc8b79008e7ec973228bad1fa20d7a94bba/gym_hil-0.1.14-py3-none-any.whl", hash = "sha256:9a2799d47a4561e0b0bb8d37fb3d84934657240be328d13991ea06758726533d", size = 5750805, upload-time = "2026-06-10T09:16:36.827Z" }, 
] [[package]] @@ -1881,7 +1881,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/e6/3e/ffad88145b342d5a9 [[package]] name = "hf-libero" -version = "0.1.3" +version = "0.1.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "bddl", marker = "sys_platform == 'linux'" }, @@ -1902,7 +1902,10 @@ dependencies = [ { name = "transformers", marker = "sys_platform == 'linux'" }, { name = "wandb", marker = "sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7e/ca/7f1c90aedcd067d608681cf03469ae548990ba0806f68a67927dcc801f04/hf_libero-0.1.3.tar.gz", hash = "sha256:0d6b9a215a658db86f66c03d063d6d877d2e9f96d2d326cfa9f43ba4da4a6d5a", size = 2960521, upload-time = "2025-11-03T17:58:00.003Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/aa/4e9eb8715e0bff9cb6553db563a35d253393097d446f82bd53575e8b253d/hf_libero-0.1.4.tar.gz", hash = "sha256:c058d67ad5a2b589529c14d614282ef4cca3a7763dafa134f58a6c9039657e34", size = 2961319, upload-time = "2026-06-10T09:56:13.994Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/79/c286b894c051988d062241682834df915c945bcf51009ffdffbe5ecf69bf/hf_libero-0.1.4-py3-none-any.whl", hash = "sha256:207f76e2f28bff30f78132223d8592fe8f64b1f8fd90ce7024948ada0d7e2c27", size = 3169084, upload-time = "2026-06-10T09:56:12.441Z" }, +] [[package]] name = "hf-xet" @@ -3090,12 +3093,12 @@ requires-dist = [ { name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" }, { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" }, { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" }, - { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" }, - { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.13,<0.2.0" }, + { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" }, + { name = "gym-hil", marker = "extra == 
'hilserl'", specifier = ">=0.1.14,<0.2.0" }, { name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" }, { name = "gymnasium", specifier = ">=1.1.1,<2.0.0" }, { name = "hebi-py", marker = "extra == 'phone'", specifier = ">=2.8.0,<2.12.0" }, - { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.3,<0.2.0" }, + { name = "hf-libero", marker = "sys_platform == 'linux' and extra == 'libero'", specifier = ">=0.1.4,<0.2.0" }, { name = "hidapi", marker = "extra == 'gamepad'", specifier = ">=0.14.0,<0.15.0" }, { name = "huggingface-hub", specifier = ">=1.0.0,<2.0.0" }, { name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" }, From 41166b39fb8bacdd8f916d700064c5f64892bc0a Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:07:42 +0200 Subject: [PATCH 05/27] fix(train): synchronize EpisodeAwareSampler shuffling across ranks and gate dataset download per node (#3768) * fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync In distributed training, accelerate can only synchronize the shuffle permutation across ranks when the sampler exposes a generator attribute. EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch shards relied on every rank's global CPU RNG staying in lockstep forever; any rank-asymmetric RNG consumption (e.g. eval rollouts on the main process only) silently desynced the permutations and ranks trained on overlapping/missing samples. * fix(train): seed sampler generator and gate dataset download per node - Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. 
- Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently. --- src/lerobot/datasets/sampler.py | 8 +++++++- src/lerobot/scripts/lerobot_train.py | 20 +++++++++++++++----- tests/datasets/test_sampler.py | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 2bf7ab922..64d871907 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -30,6 +30,7 @@ class EpisodeAwareSampler: drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, + generator: torch.Generator | None = None, ): """Sampler that optionally incorporates episode boundary information. @@ -41,6 +42,10 @@ class EpisodeAwareSampler: drop_n_first_frames: Number of frames to drop from the start of each episode. drop_n_last_frames: Number of frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. + generator: Generator used for shuffling. Exposing this attribute (even when None) lets + `accelerate` register it as the synchronized RNG in distributed training, so + every rank draws the same permutation and batch shards stay disjoint. When + None, shuffling falls back to the global torch RNG. 
""" if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") @@ -73,10 +78,11 @@ class EpisodeAwareSampler: self.indices = indices self.shuffle = shuffle + self.generator = generator def __iter__(self) -> Iterator[int]: if self.shuffle: - for i in torch.randperm(len(self.indices)): + for i in torch.randperm(len(self.indices), generator=self.generator): yield self.indices[i] else: for i in self.indices: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4ddef3105..3d210f00b 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: main process downloads first to avoid race conditions - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: each node's local main process downloads first to avoid + # race conditions (the global main process only exists on node 0, so gating on it would let + # all ranks of the other nodes download and build the Arrow cache concurrently). + if accelerator.is_local_main_process: + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset - if not is_main_process: + # Now all other processes can safely load the dataset from the local cache + if not accelerator.is_local_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. 
@@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # create dataloader for offline training if hasattr(active_cfg, "drop_n_last_frames"): shuffle = False + # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare + # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even + # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). + sampler_generator = torch.Generator() + if cfg.seed is not None: + sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, drop_n_last_frames=active_cfg.drop_n_last_frames, shuffle=True, + generator=sampler_generator, ) else: shuffle = True diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 8bb3be8e9..95429c7ec 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -114,6 +114,30 @@ def test_shuffle(): assert set(sampler) == {0, 1, 2, 3, 4, 5} +def test_shuffle_with_generator_is_deterministic(): + # Two samplers shuffling with same-seed generators must yield identical permutations. + # This is what keeps batch shards disjoint across ranks in distributed training, where + # accelerate synchronizes the sampler's generator state instead of the global torch RNG. + sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + assert list(sampler_a) == list(sampler_b) + + # Desyncing the global RNG must not affect the permutation. 
+ sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) + order_before = list(sampler_c) + sampler_c.generator.manual_seed(42) + torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would + assert list(sampler_c) == order_before + + +def test_generator_attribute_defaults_to_none(): + # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, + # so the attribute must exist even when no generator is passed. + sampler = EpisodeAwareSampler([0], [6], shuffle=True) + assert sampler.generator is None + assert set(sampler) == {0, 1, 2, 3, 4, 5} + + def test_negative_drop_first_frames_raises(): with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"): EpisodeAwareSampler([0], [10], drop_n_first_frames=-1) From 6fbcf67249fffd4eed340f2936fa1b112ba23e82 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 11 Jun 2026 18:17:26 +0200 Subject: [PATCH 06/27] chore: update readme (#3774) * chore: update readme * chore: update authors in project readme --- README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9c40e8b34..fa3e9e1a3 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ action = model.select_action(obs) robot.send_action(action) ``` -**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1. +**Supported Hardware:** SO100, LeKiwi, Koch, HopeJR, OMX, EarthRover, Reachy2, Gamepads, Keyboards, Phones, OpenARM, Unitree G1, reBot B601. While these devices are natively integrated into the LeRobot codebase, the library is designed to be extensible. You can easily implement the Robot interface to utilize LeRobot's data collection, training, and visualization tools for your own custom robot. 
@@ -101,11 +101,13 @@ lerobot-train \ --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -| Category | Models | -| -------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) | -| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | -| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | +| Category | Models | +| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md), [Multitask DiT Policy](./docs/source/policy_multi_task_dit_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0](./docs/source/pi0.mdx), [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), 
[XVLA](./docs/source/xvla.mdx), [EO-1](./docs/source/eo1.mdx), [MolmoAct2](./docs/source/molmoact2.mdx), [WALL-OSS](./docs/source/walloss.mdx) | +| **World Models** | [VLA-JEPA](./docs/source/vla_jepa.mdx) (more coming soon) | +| **Reward Models** | [SARM](./docs/source/sarm.mdx), [TOPReward](./docs/source/topreward.mdx), [Robometer](./docs/source/robometer.mdx) | Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub @@ -133,6 +135,7 @@ Learn how to implement your own simulation environment or benchmark and distribu - **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. +- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot. 
## Citation @@ -140,7 +143,7 @@ If you use LeRobot in your project, please cite the GitHub repository to acknowl ```bibtex @misc{cadene2024lerobot, - author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, + author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Meftah, Khalil and Ellerbach, Maxime and Moss, Jess and Wolf, Thomas}, title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch}, howpublished = "\url{https://github.com/huggingface/lerobot}", year = {2024} From 1edc83a0eff88b116d0cfafe74999478c89899f9 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 11 Jun 2026 19:07:28 +0200 Subject: [PATCH 07/27] feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup (#3773) * feat(training): bump accelerate + use reduction types for tracked metrics in a multi rank setup * chore: address feedback --- pyproject.toml | 5 +- src/lerobot/scripts/lerobot_train.py | 51 +++++++++++++----- src/lerobot/utils/logging_utils.py | 51 +++++++++++++++++- tests/utils/test_logging_utils.py | 78 +++++++++++++++++++++++++++- uv.lock | 16 +++--- 5 files changed, 177 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f72cfa6dd..89200d1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ dataset = [ ] training = [ "lerobot[dataset]", - "accelerate>=1.10.0,<2.0.0", + "lerobot[accelerate-dep]", "wandb>=0.24.0,<0.25.0", ] hardware = [ @@ -142,6 +142,7 @@ pygame-dep = 
["pygame>=2.5.1,<2.7.0"] # (noble ships urdfdom 3.x). Cap below 0.9.16 until system urdfdom 4.x is broadly available. placo-dep = ["placo>=0.9.6,<0.9.16"] transformers-dep = ["transformers>=5.4.0,<5.6.0"] +accelerate-dep = ["accelerate>=1.14.0,<2.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] @@ -199,7 +200,7 @@ wallx = [ ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] molmoact2 = ["lerobot[transformers-dep]", "lerobot[peft-dep]", "lerobot[scipy-dep]"] -smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "lerobot[accelerate-dep]"] multi_task_dit = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]"] groot = [ "lerobot[transformers-dep]", diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3d210f00b..a35d4229d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -99,6 +99,9 @@ def update_policy( start_time = time.perf_counter() policy.train() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + # Compute sample weights if a weighter is provided sample_weights = None weight_stats = None @@ -158,6 +161,8 @@ def update_policy( train_metrics.grad_norm = grad_norm.item() train_metrics.lr = optimizer.param_groups[0]["lr"] train_metrics.update_s = time.perf_counter() - start_time + if torch.cuda.is_available(): + train_metrics.gpu_mem_gb = torch.cuda.max_memory_allocated() / (1024**3) return train_metrics, output_dict @@ -434,12 +439,22 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): policy.train() train_metrics = { - "loss": AverageMeter("loss", ":.3f"), + # Per-rank loss reflects only one shard of the global batch; mean recovers the loss DDP + # is actually optimizing. 
grad_norm and lr are already identical on every rank (post + # gradient sync / deterministic scheduler) so reducing them would be a no-op collective. + "loss": AverageMeter("loss", ":.3f", reduction="mean"), "grad_norm": AverageMeter("grdn", ":.3f"), "lr": AverageMeter("lr", ":0.1e"), - "update_s": AverageMeter("updt_s", ":.3f"), - "dataloading_s": AverageMeter("data_s", ":.3f"), + # Report the slowest rank for bottleneck-style timings so multi-GPU runs surface the + # true straggler instead of rank 0's view. + "update_s": AverageMeter("updt_s", ":.3f", reduction="max"), + "dataloading_s": AverageMeter("data_s", ":.3f", reduction="max"), + # Derived from the post-reduce max step time; set once per log window on the main rank. + "samples_per_s": AverageMeter("smp/s", ":.0f"), } + if torch.cuda.is_available(): + # max() because headroom is gated by the worst-case rank. + train_metrics["gpu_mem_gb"] = AverageMeter("mem_gb", ":.2f", reduction="max") # Keep global batch size for logging; MetricsTracker handles world size internally. 
effective_batch_size = cfg.batch_size * accelerator.num_processes @@ -491,21 +506,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: progbar.update(1) train_tracker.step() - is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 if is_log_step: - logging.info(train_tracker) - if wandb_logger: - wandb_log_dict = train_tracker.to_dict() - if output_dict: - wandb_log_dict.update(output_dict) - # Log sample weighting statistics if enabled - if sample_weighter is not None: - weighter_stats = sample_weighter.get_stats() - wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) - wandb_logger.log_dict(wandb_log_dict, step) + # Collective reduce must run on every rank, before the main-process gate below. + train_tracker.reduce_across_ranks() + if is_main_process: + # Cluster-wide throughput, derived from the already-reduced (max) step time so it + # reflects the slowest rank — which is what actually gates the next iteration. 
+ step_time = train_tracker.update_s.avg + train_tracker.dataloading_s.avg + if step_time > 0: + train_tracker.samples_per_s = effective_batch_size / step_time + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = train_tracker.to_dict() + if output_dict: + wandb_log_dict.update(output_dict) + # Log sample weighting statistics if enabled + if sample_weighter is not None: + weighter_stats = sample_weighter.get_stats() + wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) + wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() if cfg.save_checkpoint and is_saving_step: diff --git a/src/lerobot/utils/logging_utils.py b/src/lerobot/utils/logging_utils.py index 0ce596f55..20673fc30 100644 --- a/src/lerobot/utils/logging_utils.py +++ b/src/lerobot/utils/logging_utils.py @@ -13,21 +13,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict from collections.abc import Callable from typing import Any +import torch + from .utils import format_big_number +_VALID_REDUCTIONS = ("none", "max", "mean", "sum") + class AverageMeter: """ Computes and stores the average and current value Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py + + Args: + name: Display name of the metric. + fmt: Format string used when rendering the metric. + reduction: Cross-process reduction applied by :meth:`MetricsTracker.reduce_across_ranks` + before logging. One of ``"none"`` (per-rank value, default), ``"max"``, ``"mean"``, + or ``"sum"``. Use ``"max"`` for bottleneck-style metrics (e.g. dataloading or + update wall time) so multi-GPU runs report the slowest rank rather than rank 0. 
""" - def __init__(self, name: str, fmt: str = ":f"): + def __init__(self, name: str, fmt: str = ":f", reduction: str = "none"): + if reduction not in _VALID_REDUCTIONS: + raise ValueError( + f"Invalid reduction {reduction!r} for AverageMeter; expected one of {_VALID_REDUCTIONS}." + ) self.name = name self.fmt = fmt + self.reduction = reduction self.reset() def reset(self) -> None: @@ -138,6 +156,37 @@ class MetricsTracker: self.episodes = self.samples / self._avg_samples_per_ep self.epochs = self.samples / self._num_frames + def reduce_across_ranks(self) -> None: + """ + Synchronises the running averages of every metric whose ``reduction`` is not ``"none"`` + across all distributed processes (in-place). + + This is a collective operation and MUST be invoked on every rank — typically just before + logging. With no accelerator or in single-process runs it is a no-op. Without it, metrics + reported by the main process only reflect rank 0; for bottleneck-style timings + (``dataloading_s``, ``update_s``, ...) that means the slowest worker's stall is invisible. + """ + if self.accelerator is None or self.accelerator.num_processes <= 1: + return + + buckets: dict[str, list[str]] = defaultdict(list) + for name, meter in self.metrics.items(): + if meter.reduction != "none": + buckets[meter.reduction].append(name) + if not buckets: + return + + device = self.accelerator.device + for reduction, names in buckets.items(): + tensor = torch.tensor([self.metrics[n].avg for n in names], dtype=torch.float32, device=device) + reduced = self.accelerator.reduce(tensor, reduction=reduction) + for name, value in zip(names, reduced.tolist(), strict=True): + meter = self.metrics[name] + # Preserve avg == sum / count so a later .update() on this meter accumulates + # against the cluster view, not the stale per-rank history. 
+ meter.avg = value + meter.sum = value * meter.count + def __str__(self) -> str: display_list = [ f"step:{format_big_number(self.steps)}", diff --git a/tests/utils/test_logging_utils.py b/tests/utils/test_logging_utils.py index 1207534c0..aa851bd2a 100644 --- a/tests/utils/test_logging_utils.py +++ b/tests/utils/test_logging_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import pytest +import torch from lerobot.utils.logging_utils import AverageMeter, MetricsTracker @@ -25,8 +26,16 @@ def mock_metrics(): class MockAccelerator: - def __init__(self, num_processes: int): + def __init__(self, num_processes: int, reduce_fn=None): self.num_processes = num_processes + self.device = torch.device("cpu") + self._reduce_fn = reduce_fn + + def reduce(self, tensor, reduction="mean"): + # In single-process tests we just want a deterministic stand-in for accelerate's reduce. + if self._reduce_fn is not None: + return self._reduce_fn(tensor, reduction) + return tensor def test_average_meter_initialization(): @@ -157,3 +166,70 @@ def test_metrics_tracker_reset_averages(mock_metrics): tracker.reset_averages() assert tracker.loss.avg == 0.0 assert tracker.accuracy.avg == 0.0 + + +def test_average_meter_invalid_reduction(): + with pytest.raises(ValueError): + AverageMeter("loss", reduction="median") + + +def test_average_meter_reduction_stored(): + meter = AverageMeter("updt_s", reduction="max") + assert meter.reduction == "max" + + +def test_metrics_tracker_reduce_across_ranks_no_accelerator(): + metrics = {"update_s": AverageMeter("update_s", reduction="max")} + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=metrics) + tracker.update_s = 0.5 + tracker.reduce_across_ranks() # no-op without accelerator + assert tracker.update_s.avg == 0.5 + + +def test_metrics_tracker_reduce_across_ranks_single_process(): + metrics = {"update_s": AverageMeter("update_s", reduction="max")} + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + 
num_episodes=50, + metrics=metrics, + accelerator=MockAccelerator(num_processes=1), + ) + tracker.update_s = 0.5 + tracker.reduce_across_ranks() # no-op when world size is 1 + assert tracker.update_s.avg == 0.5 + + +def test_metrics_tracker_reduce_across_ranks_invokes_reduce(): + captured = {} + + def fake_reduce(tensor, reduction): + captured["reduction"] = reduction + captured["values"] = tensor.clone() + # Pretend the slowest rank reported 0.9 instead of this rank's 0.4. + return torch.tensor([0.9], dtype=tensor.dtype, device=tensor.device) + + metrics = { + "loss": AverageMeter("loss"), # reduction="none" -> not touched + "update_s": AverageMeter("update_s", reduction="max"), + } + tracker = MetricsTracker( + batch_size=32, + num_frames=1000, + num_episodes=50, + metrics=metrics, + accelerator=MockAccelerator(num_processes=4, reduce_fn=fake_reduce), + ) + tracker.loss = 1.0 + tracker.update_s = 0.4 + tracker.reduce_across_ranks() + + assert captured["reduction"] == "max" + assert torch.allclose(captured["values"], torch.tensor([0.4])) + assert tracker.update_s.avg == pytest.approx(0.9) + # Metrics without a reduction stay untouched. + assert tracker.loss.avg == 1.0 + # Invariant: avg == sum / count must hold after reduce, so subsequent .update() calls + # accumulate against the cluster view rather than the stale per-rank sum. 
+ meter = tracker.update_s + assert meter.sum / meter.count == pytest.approx(meter.avg) diff --git a/uv.lock b/uv.lock index 3a7129dac..f4f854b62 100644 --- a/uv.lock +++ b/uv.lock @@ -59,7 +59,7 @@ wheels = [ [[package]] name = "accelerate" -version = "1.13.0" +version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, @@ -71,9 +71,9 @@ dependencies = [ { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform != 'linux'" }, { name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/75/94cd5d389649578aca399e5aa822637eec18319a1dadc400ffe2f9a7493f/accelerate-1.14.0.tar.gz", hash = "sha256:41b9c4377a54e0b460a959b0defa1b736e4ca0a2373252d9a539964c2afe3c8d", size = 412167, upload-time = "2026-06-11T13:45:52.326Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" }, + { url = "https://files.pythonhosted.org/packages/a8/db/253133d7e7cb40d3af384bb2f5c0b4a2b7fdcffbc95c688cc67a20a3c103/accelerate-1.14.0-py3-none-any.whl", hash = "sha256:e94390c2863b873be18f623f9df48a0d8fe5eff13ea7f1a00092b0a7904888c6", size = 389246, upload-time = "2026-06-11T13:45:50.477Z" }, ] [[package]] @@ -2687,6 +2687,9 @@ dependencies = [ ] [package.optional-dependencies] +accelerate-dep = [ + { name = 
"accelerate" }, +] all = [ { name = "accelerate" }, { name = "av" }, @@ -3073,8 +3076,7 @@ xvla = [ [package.metadata] requires-dist = [ - { name = "accelerate", marker = "extra == 'smolvla'", specifier = ">=1.7.0,<2.0.0" }, - { name = "accelerate", marker = "extra == 'training'", specifier = ">=1.10.0,<2.0.0" }, + { name = "accelerate", marker = "extra == 'accelerate-dep'", specifier = ">=1.14.0,<2.0.0" }, { name = "av", marker = "extra == 'av-dep'", specifier = ">=15.0.0,<16.0.0" }, { name = "cmake", specifier = ">=3.29.0.1,<4.2.0" }, { name = "contourpy", marker = "extra == 'matplotlib-dep'", specifier = ">=1.3.0,<2.0.0" }, @@ -3104,6 +3106,8 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'notebook'", specifier = ">=6.0.0,<7.0.0" }, { name = "jsonlines", marker = "extra == 'dataset'", specifier = ">=4.0.0,<5.0.0" }, { name = "jupyter", marker = "extra == 'notebook'", specifier = ">=1.0.0,<2.0.0" }, + { name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'smolvla'" }, + { name = "lerobot", extras = ["accelerate-dep"], marker = "extra == 'training'" }, { name = "lerobot", extras = ["aloha"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["async"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["av-dep"], marker = "extra == 'dataset'" }, @@ -3279,7 +3283,7 @@ requires-dist = [ { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ] -provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", 
"lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] +provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "accelerate-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] [[package]] name = "librt" From 87242cfced5228b6181fb0047021d684d424df38 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 11 Jun 2026 19:13:14 +0200 Subject: [PATCH 08/27] chore(dependencies): relax grpc-related bounds (#3777) Signed-off-by: Steven Palma --- pyproject.toml | 13 +++++++++---- uv.lock | 12 ++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 89200d1ab..e43f8ef81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,8 +115,8 @@ dataset = [ ] training = [ "lerobot[dataset]", + "wandb>=0.24.0,<0.28.0", "lerobot[accelerate-dep]", - "wandb>=0.24.0,<0.25.0", ] hardware = [ "lerobot[pynput-dep]", @@ -142,8 +142,8 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"] # (noble ships urdfdom 3.x).
Cap below 0.9.16 until system urdfdom 4.x is broadly available. placo-dep = ["placo>=0.9.6,<0.9.16"] transformers-dep = ["transformers>=5.4.0,<5.6.0"] +grpcio-dep = ["grpcio>=1.73.1,<2.0.0", "protobuf>=6.31.1,<8.0.0"] accelerate-dep = ["accelerate>=1.14.0,<2.0.0"] -grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] can-dep = ["python-can>=4.2.0,<5.0.0"] peft-dep = ["peft>=0.18.0,<1.0.0"] scipy-dep = ["scipy>=1.14.0,<2.0.0"] @@ -178,7 +178,12 @@ unitree_g1 = [ "lerobot[matplotlib-dep]", "lerobot[pygame-dep]", ] -reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] +# reachy2-sdk caps grpcio<=1.73.1 and protobuf<=6.32.0; quarantined here so downstream users aren't held back. reachy2-sdk is unlikely to release new versions. +reachy2 = [ + "reachy2_sdk>=1.0.15,<1.1.0", + "grpcio<=1.73.1", + "protobuf<=6.32.0", +] # Seeed Studio reBot B601-DM follower (motorbridge / CAN) + StarArm102 / reBot Arm 102 # leader (motorbridge-smart-servo / FashionStar UART servos). rebot = ["lerobot[motorbridge-dep]", "lerobot[motorbridge-smart-servo-dep]"] @@ -225,7 +230,7 @@ async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] # Development -dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] +dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"] test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0.0,<8.0.0", "mock-serial>=0.0.1,<0.1.0 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] diff --git a/uv.lock b/uv.lock index f4f854b62..4072828e7 100644 --- a/uv.lock +++ b/uv.lock @@ -2989,6 +2989,8 @@ qwen-vl-utils-dep = [ { name = "qwen-vl-utils" }, ] reachy2 = [ + { name = 
"grpcio" }, + { name = "protobuf" }, { name = "reachy2-sdk" }, ] rebot = [ @@ -3093,8 +3095,9 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" }, { name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" }, { name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" }, - { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" }, - { name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" }, + { name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = ">=1.73.1,<2.0.0" }, + { name = "grpcio", marker = "extra == 'reachy2'", specifier = "<=1.73.1" }, + { name = "grpcio-tools", marker = "extra == 'dev'", specifier = ">=1.73.1,<2.0.0" }, { name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.4,<0.2.0" }, { name = "gym-hil", marker = "extra == 'hilserl'", specifier = ">=0.1.14,<0.2.0" }, { name = "gym-pusht", marker = "extra == 'pusht'", specifier = ">=0.1.5,<0.2.0" }, @@ -3244,7 +3247,8 @@ requires-dist = [ { name = "pillow", specifier = ">=10.0.0,<13.0.0" }, { name = "placo", marker = "extra == 'placo-dep'", specifier = ">=0.9.6,<0.9.16" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.7.0,<5.0.0" }, - { name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<6.32.0" }, + { name = "protobuf", marker = "extra == 'grpcio-dep'", specifier = ">=6.31.1,<8.0.0" }, + { name = "protobuf", marker = "extra == 'reachy2'", specifier = "<=6.32.0" }, { name = "pyarrow", marker = "extra == 'dataset'", specifier = ">=21.0.0,<30.0.0" }, { name = "pydantic", marker = "extra == 'sarm'", specifier = ">=2.0.0,<3.0.0" }, { name = "pygame", marker = "extra == 'pygame-dep'", specifier = ">=2.5.1,<2.7.0" }, @@ -3281,7 +3285,7 @@ requires-dist = [ { name = "torchvision", marker = "sys_platform == 'linux'", specifier = ">=0.22.0,<0.27.0", index = 
"https://download.pytorch.org/whl/cu128" }, { name = "tqdm", specifier = ">=4.66.0,<5.0.0" }, { name = "transformers", marker = "extra == 'transformers-dep'", specifier = ">=5.4.0,<5.6.0" }, - { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, + { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.28.0" }, ] provides-extras = ["dataset", "training", "hardware", "viz", "core-scripts", "evaluation", "dataset-viz", "av-dep", "pygame-dep", "placo-dep", "transformers-dep", "accelerate-dep", "grpcio-dep", "can-dep", "peft-dep", "scipy-dep", "diffusers-dep", "qwen-vl-utils-dep", "matplotlib-dep", "pyserial-dep", "deepdiff-dep", "pynput-dep", "pyzmq-dep", "motorbridge-dep", "motorbridge-smart-servo-dep", "feetech", "dynamixel", "damiao", "robstride", "openarms", "gamepad", "hopejr", "lekiwi", "unitree-g1", "reachy2", "rebot", "kinematics", "intelrealsense", "phone", "diffusion", "wallx", "pi", "molmoact2", "smolvla", "multi-task-dit", "groot", "sarm", "robometer", "topreward", "xvla", "eo1", "hilserl", "vla-jepa", "async", "peft", "dev", "notebook", "test", "video-benchmark", "aloha", "pusht", "libero", "metaworld", "all"] From 0e9bd9e6fb7c4c82c6be38e3fe82103354df1b40 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 12 Jun 2026 11:29:26 +0200 Subject: [PATCH 09/27] feat(trim): adding optional trimming option in reencode_video (#3779) * feat(trim): adding optional trimming option in reencode_video * tests(trim): add trimming test --------- Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- src/lerobot/datasets/video_utils.py | 22 +++++++++++++++++++++- tests/datasets/test_video_encoding.py | 13 +++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 84ab56e08..ca90fba45 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -481,8 +481,10 @@ def
reencode_video( encoder_threads: int | None = None, log_level: int | None = av.logging.WARNING, overwrite: bool = False, + start_time_s: float | None = None, + end_time_s: float | None = None, ) -> None: - """Re-encode a video file using the given encoder configuration. + """Re-encode a video file, optionally trimming it to ``[start_time_s, end_time_s)``. Args: input_video_path: Existing video file to read. @@ -491,10 +493,17 @@ def reencode_video( encoder_threads: Optional thread count forwarded to :meth:`VideoEncoderConfig.get_codec_options`. log_level: libav log level while encoding, or ``None`` to leave logging unchanged. Defaults to WARNING. overwrite: When ``False`` and ``output_video_path`` already exists, skip and log a warning. + start_time_s: When set, trim the output to start at this timestamp (seconds). + end_time_s: When set, trim the output to end at this timestamp (seconds, exclusive). """ camera_encoder = camera_encoder or camera_encoder_defaults() + if (start_time_s is not None and start_time_s < 0) or (end_time_s is not None and end_time_s < 0): + raise ValueError(f"Trim times must be non-negative, got start={start_time_s}, end={end_time_s}.") + if start_time_s is not None and end_time_s is not None and end_time_s <= start_time_s: + raise ValueError(f"end_time_s ({end_time_s}) must be greater than start_time_s ({start_time_s}).") + output_video_path = Path(output_video_path) if output_video_path.exists() and not overwrite: @@ -526,6 +535,10 @@ def reencode_video( width = int(in_stream.width) height = int(in_stream.height) + # Seek to the keyframe at or before start_time_s to avoid reading from the start. 
+ if start_time_s is not None: + src.seek(int(start_time_s * av.time_base), backward=True) + with av.open( tmp_output_video_path, mode="w", @@ -539,7 +552,14 @@ def reencode_video( out_stream.height = height for frame in src.decode(in_stream): + frame_time_s = frame.time + if start_time_s is not None and frame_time_s < start_time_s: + continue + if end_time_s is not None and frame_time_s >= end_time_s: + break frame = frame.reformat(width=width, height=height, format=pix_fmt) + if start_time_s is not None: + frame.pts = None # reset timestamps so the trimmed output starts at t=0 packet = out_stream.encode(frame) if packet: dst.mux(packet) diff --git a/tests/datasets/test_video_encoding.py b/tests/datasets/test_video_encoding.py index 1af61e9f9..2a35f3210 100644 --- a/tests/datasets/test_video_encoding.py +++ b/tests/datasets/test_video_encoding.py @@ -504,6 +504,19 @@ class TestReencodeVideo: assert info["video.g"] == 6 assert info["video.crf"] == 23 + @require_h264 + def test_reencode_video_trim_window(self, tmp_path): + src = TEST_ARTIFACTS_DIR / "clip_6frames.mp4" + out = tmp_path / "trim_window.mp4" + cfg = VideoEncoderConfig(vcodec="h264") + reencode_video(src, out, camera_encoder=cfg, start_time_s=0.05, end_time_s=0.12, overwrite=True) + + with av.open(str(out)) as container: + frames = list(container.decode(video=0)) + # Only the frames at 0.067 and 0.1 s fall inside [0.05, 0.12). 
+ assert len(frames) == 2 + assert frames[0].time == pytest.approx(0.0, abs=1e-3) + class TestConcatenateVideoFiles: def test_two_clips_frame_count(self, tmp_path): From 234c768dfb6fbf7250602003dec0f91b7bfdeb86 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:47:16 +0200 Subject: [PATCH 10/27] feat(datasets): deterministic, resumable shuffling for EpisodeAwareSampler (#3769) * fix(datasets): expose a generator on EpisodeAwareSampler for distributed shuffle sync In distributed training, accelerate can only synchronize the shuffle permutation across ranks when the sampler exposes a generator attribute. EpisodeAwareSampler shuffled via the global torch RNG, so disjoint batch shards relied on every rank's global CPU RNG staying in lockstep forever; any rank-asymmetric RNG consumption (e.g. eval rollouts on the main process only) silently desynced the permutations and ranks trained on overlapping/missing samples. Co-Authored-By: Claude Fable 5 * fix(train): seed sampler generator and gate dataset download per node - Pass a generator seeded with cfg.seed to EpisodeAwareSampler so accelerator.prepare registers it as the synchronized RNG and the shuffle order is reproducible. - Gate the initial make_dataset call on is_local_main_process instead of is_main_process: the global main process only exists on node 0, so on every other node all local ranks were downloading the dataset and building the Arrow cache concurrently. Co-Authored-By: Claude Fable 5 * feat(datasets): add DeterministicEpisodeAwareSampler with O(1) memory and sample-exact resume Add a sampler that never materializes frame indices: it stores only per-episode boundaries (numpy, a few bytes per episode) and maps logical positions to frame indices on the fly with searchsorted. 
Shuffling uses a seeded Feistel permutation over [0, num_frames) (cycle-walking to the exact domain), so the data order is a pure function of (seed, epoch): - no RNG state to synchronize across distributed ranks, - constant memory and zero epoch-boundary cost at any dataset size, - O(1) seek to any position, enabling sample-exact resume. Opt in with --deterministic_sampler=true. On resume, lerobot-train maps the checkpointed step back to (epoch, start_index) via compute_sampler_state and continues at the exact sample where the run left off (up to accelerate's even_batches padding at epoch boundaries). The shuffle is pseudo-random rather than a true uniform permutation, the standard trade-off in large-scale training loaders. Co-Authored-By: Claude Fable 5 * refactor(datasets): fold deterministic mode into EpisodeAwareSampler Instead of a parallel DeterministicEpisodeAwareSampler class, extend the existing EpisodeAwareSampler with a deterministic=True mode (seeded Feistel permutation, epoch auto-advance, state_dict/load_state_dict). The default mode is behavior-identical: same torch.randperm consumption and the same generator contract accelerate synchronizes; the O(N) Python index list is replaced by O(num_episodes) boundary arrays in both modes, with `indices` kept as a back-compat property. Passing a generator together with deterministic=True is rejected, and the state/seek methods raise outside deterministic mode. Co-Authored-By: Claude Fable 5 * feat(train): enable deterministic_sampler by default Deterministic data order (sample-exact resume, no cross-rank RNG sync, O(1) sampler memory) is now the default for map-style training; set deterministic_sampler=false to restore the legacy RNG-based shuffle. Streaming datasets ignore the flag (the sampler path only applies to map-style datasets), replacing the previous hard validation error so streaming configs keep working with the new default. 
Co-Authored-By: Claude Fable 5 * feat(datasets): default EpisodeAwareSampler to deterministic mode and trim comments deterministic=True is now the class default as well as the training default; the legacy RNG path requires an explicit deterministic=False (the train script's non-deterministic branch passes it). Docstrings and inline comments slimmed down across the changed files. Co-Authored-By: Claude Fable 5 * test(sampler): drain resumed trillion-frame sampler via iter() to avoid list() prealloc list(sampler) calls PyObject_LengthHint -> __len__ (the full 10**12 epoch length) and preallocates that many slots before iterating, OOMing even though the resumed epoch only yields 3 frames. Collect through the iterator (no length hint) so the test exercises the real O(1) seek/drain instead of CPython's list growth heuristic. * fix(datasets): guard Feistel cycle-walking loop against non-convergence Replace the unbounded while True in EpisodeAwareSampler._permute with a bounded for loop capped at _MAX_CYCLE_WALK_STEPS (100) and raise RuntimeError if the cycle-walk fails to land in [0, num_frames). The loop is expected to converge in <4 steps on the chosen power-of-two domain, so the bound is a safety net that should never trip in practice but prevents a pathological infinite loop. https://claude.ai/code/session_01HQ15tFrBsHYScjGWosEv22 * fix(datasets): make deterministic-sampler resume robust to world-size changes compute_sampler_state mapped a checkpointed step back to (epoch, start_index) using the *current* num_processes, but the number of sampler positions a step consumes scales with the world size that produced it. Resuming on a different GPU count therefore landed on the wrong epoch/offset, silently re-seeing or skipping data. Record num_processes in training_step.json at checkpoint time and feed the checkpoint's value into compute_sampler_state on resume, so the data order resumes at the right position regardless of the new world size. 
Warn when the world size changed (the global offset is correct, but per-rank sample-exactness needs the same topology). Old checkpoints without the field fall back to the current world size. Also document compute_sampler_state's assumptions explicitly: num_processes / batch_size must match the checkpointing run, and accelerate's even_batches=True padding is mirrored by the ceil(... / num_processes) term. Co-Authored-By: Claude Fable 5 Co-authored-by: Cursor * style: apply ruff-format to lerobot_train.py Collapse the compute_sampler_state(...) call onto one line so the ruff-format pre-commit hook passes (fixes the failing CI check). Co-authored-by: Cursor * refactor(datasets): use seeded torch.randperm instead of Feistel in EpisodeAwareSampler Drop the Feistel permutation (and its SplitMix64 hash / cycle-walking) in favor of a torch.randperm seeded from (seed, epoch). The deterministic mode keeps its key properties - data order is a pure function of (seed, epoch), so it reproduces on every rank with no global-RNG synchronization, and - state_dict / load_state_dict still resume sample-exactly, now by regenerating the epoch's permutation and slicing from the saved offset. Construction stays O(num_episodes) (only episode boundaries are stored, never a per-frame index list). The trade-off vs Feistel: the per-epoch shuffle is again O(num_frames) memory (the randperm tensor) and no longer O(1)-seekable, in exchange for ~30 fewer LOC and a truly uniform shuffle. Tests updated: the trillion-frame O(1) test is replaced with a boundary-storage check and a scale resume-exactness test. Co-authored-by: Cursor * refactor(datasets): make EpisodeAwareSampler always deterministic With Feistel gone, deterministic and legacy modes were both just torch.randperm and the deterministic path strictly dominated (reproducible across ranks via the (seed, epoch) seed, no accelerate generator sync, resumable). 
Collapse to a single path and drop the redundant flag: - remove the `deterministic` and `generator` constructor args, `_iter_default`, and `_require_deterministic`; `set_epoch` / `state_dict` / `load_state_dict` are now unconditional - remove the `deterministic_sampler` train config field and the legacy generator branch in lerobot_train.py (non-streaming map datasets always use the sampler) - drop the now-obsolete generator/legacy tests Note: removes the `generator` kwarg from EpisodeAwareSampler (back-compat break vs main); the order is now a pure function of (seed, epoch), so no cross-rank RNG sync is needed. Co-authored-by: Cursor * fix(datasets): address sampler review (batch_size resume guard + docs) - Record batch_size in training_step.json alongside num_processes and feed the checkpoint's value into compute_sampler_state on resume; warn when it differs (per-rank sample-exactness needs the same batch size). - Document the set_epoch vs __iter__ auto-advance coupling on EpisodeAwareSampler (callers should rely on exactly one mechanism per run). - Note the broadened (reproducibility-breaking) sampler guard and the no-generator distributed sharding correctness in lerobot_train.py. - Add load_training_batch_size + parallel tests. Co-authored-by: Cursor * fix(train): download dataset once on the global main process Gate the training dataset download on the global is_main_process (download once to the shared dataset root, barrier, then every other rank reads the already-populated copy) instead of per-node is_local_main_process. LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. Assumes the dataset root / HF cache is on storage shared across nodes. Co-authored-by: Cursor * chore(datasets): trim sampler comment and drop duplicate tests Remove the verbose dataloader-guard comment and the two EpisodeAwareSampler tests that duplicated existing validation/warning coverage (no coverage loss). 
Co-authored-by: Cursor --------- Co-authored-by: Claude Fable 5 Co-authored-by: Cursor --- src/lerobot/common/train_utils.py | 41 ++++++- src/lerobot/datasets/__init__.py | 3 +- src/lerobot/datasets/sampler.py | 160 ++++++++++++++++++++------- src/lerobot/scripts/lerobot_train.py | 64 ++++++++--- tests/datasets/test_sampler.py | 113 +++++++++++++++---- tests/utils/test_train_utils.py | 24 ++++ 6 files changed, 324 insertions(+), 81 deletions(-) diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 21ee514de..2d23b4003 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -49,8 +49,19 @@ def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Pa return output_dir / CHECKPOINTS_DIR / step_identifier -def save_training_step(step: int, save_dir: Path) -> None: - write_json({"step": step}, save_dir / TRAINING_STEP) +def save_training_step( + step: int, save_dir: Path, num_processes: int | None = None, batch_size: int | None = None +) -> None: + state: dict = {"step": step} + # num_processes and batch_size are recorded so a resumed run can detect a changed world size or + # batch size: the sampler's resume offset is computed from the (num_processes, batch_size) that + # produced `step`, since both scale how many sampler positions a step consumes (see + # compute_sampler_state). 
+ if num_processes is not None: + state["num_processes"] = num_processes + if batch_size is not None: + state["batch_size"] = batch_size + write_json(state, save_dir / TRAINING_STEP) def load_training_step(save_dir: Path) -> int: @@ -58,6 +69,16 @@ def load_training_step(save_dir: Path) -> int: return training_step["step"] +def load_training_num_processes(checkpoint_dir: Path) -> int | None: + """World size recorded at checkpoint time, or None for checkpoints written before it was stored.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("num_processes") + + +def load_training_batch_size(checkpoint_dir: Path) -> int | None: + """Per-process batch size recorded at checkpoint time, or None for older checkpoints.""" + return load_json(checkpoint_dir / TRAINING_STATE_DIR / TRAINING_STEP).get("batch_size") + + def update_last_checkpoint(checkpoint_dir: Path) -> Path: last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK if last_checkpoint_dir.is_symlink(): @@ -75,6 +96,8 @@ def save_checkpoint( scheduler: LRScheduler | None = None, preprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None, + num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """This function creates the following directory structure: @@ -100,6 +123,10 @@ def save_checkpoint( scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. preprocessor: The preprocessor/pipeline to save. Defaults to None. postprocessor: The postprocessor/pipeline to save. Defaults to None. + num_processes (int | None, optional): Distributed world size to record for sample-exact + resume. Defaults to None (not recorded). + batch_size (int | None, optional): Per-process batch size to record for sample-exact + resume. Defaults to None (not recorded). 
""" pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) @@ -112,7 +139,9 @@ def save_checkpoint( preprocessor.save_pretrained(pretrained_dir) if postprocessor is not None: postprocessor.save_pretrained(pretrained_dir) - save_training_state(checkpoint_dir, step, optimizer, scheduler) + save_training_state( + checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size + ) def save_training_state( @@ -120,6 +149,8 @@ def save_training_state( train_step: int, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, + num_processes: int | None = None, + batch_size: int | None = None, ) -> None: """ Saves the training step, optimizer state, scheduler state, and rng state. @@ -131,10 +162,12 @@ def save_training_state( Defaults to None. scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. Defaults to None. + num_processes (int | None, optional): Distributed world size to record. Defaults to None. + batch_size (int | None, optional): Per-process batch size to record. Defaults to None. 
""" save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir.mkdir(parents=True, exist_ok=True) - save_training_step(train_step, save_dir) + save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size) save_rng_state(save_dir) if optimizer is not None: save_optimizer_state(optimizer, save_dir) diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index 2a67858d2..bd12a7248 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -50,7 +50,7 @@ from .lerobot_dataset import LeRobotDataset from .multi_dataset import MultiLeRobotDataset from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from .pyav_utils import check_video_encoder_parameters_pyav, detect_available_encoders_pyav -from .sampler import EpisodeAwareSampler +from .sampler import EpisodeAwareSampler, compute_sampler_state from .streaming_dataset import StreamingLeRobotDataset from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card from .video_utils import VideoEncodingManager @@ -82,6 +82,7 @@ __all__ = [ "aggregate_stats", "convert_image_to_video_dataset", "create_initial_features", + "compute_sampler_state", "create_lerobot_dataset_card", "column_for_style", "delete_episodes", diff --git a/src/lerobot/datasets/sampler.py b/src/lerobot/datasets/sampler.py index 64d871907..af85dff9b 100644 --- a/src/lerobot/datasets/sampler.py +++ b/src/lerobot/datasets/sampler.py @@ -14,14 +14,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import math from collections.abc import Iterator +import numpy as np import torch logger = logging.getLogger(__name__) class EpisodeAwareSampler: + """Sampler over episode frames that stores only per-episode boundaries. + + Logical positions map to frame indices on the fly (O(num_episodes) construction memory) + instead of materializing a Python list of every frame index. 
+ + Each epoch is shuffled with a `torch.randperm` seeded from `(seed, epoch)`, so the data order + is a pure function of `(seed, epoch)`: it reproduces on every rank without synchronizing the + global RNG (no `generator` to sync across distributed ranks), and `state_dict` / + `load_state_dict` resume a run sample-exactly by regenerating the epoch's permutation and + continuing from the saved offset. Each call to `__iter__` advances the epoch. During a + resumed epoch, `__len__` still reports the full length. + + Epoch advancement: `__iter__` eagerly advances the epoch, and `set_epoch` / `load_state_dict` + set it explicitly. Within a single run callers should rely on exactly one of these mechanisms, + not both: advancing the epoch by hand *and* letting `__iter__` auto-advance over the same + iterations would skip or repeat epochs. The training loop drives it purely through `__iter__` + (via `cycle`); `set_epoch` / `load_state_dict` are used only to (re)position before iteration + starts (e.g. on resume or in tests). + """ + def __init__( self, dataset_from_indices: list[int], @@ -30,63 +52,125 @@ class EpisodeAwareSampler: drop_n_first_frames: int = 0, drop_n_last_frames: int = 0, shuffle: bool = False, - generator: torch.Generator | None = None, + seed: int = 0, ): - """Sampler that optionally incorporates episode boundary information. - + """ Args: - dataset_from_indices: List of indices containing the start of each episode in the dataset. - dataset_to_indices: List of indices containing the end of each episode in the dataset. - episode_indices_to_use: List of episode indices to use. If None, all episodes are used. - Assumes that episodes are indexed from 0 to N-1. - drop_n_first_frames: Number of frames to drop from the start of each episode. - drop_n_last_frames: Number of frames to drop from the end of each episode. + dataset_from_indices: Start index of each episode in the dataset. + dataset_to_indices: End index of each episode in the dataset. 
+ episode_indices_to_use: Episode indices to use; None means all. + drop_n_first_frames: Frames to drop from the start of each episode. + drop_n_last_frames: Frames to drop from the end of each episode. shuffle: Whether to shuffle the indices. - generator: Generator used for shuffling. Exposing this attribute (even when None) lets - `accelerate` register it as the synchronized RNG in distributed training, so - every rank draws the same permutation and batch shards stay disjoint. When - None, shuffling falls back to the global torch RNG. + seed: Seed the permutation is derived from (together with the epoch). """ if drop_n_first_frames < 0: raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}") if drop_n_last_frames < 0: raise ValueError(f"drop_n_last_frames must be >= 0, got {drop_n_last_frames}") - indices = [] - for episode_idx, (start_index, end_index) in enumerate( - zip(dataset_from_indices, dataset_to_indices, strict=True) - ): - if episode_indices_to_use is None or episode_idx in episode_indices_to_use: - ep_length = end_index - start_index - if drop_n_first_frames + drop_n_last_frames >= ep_length: - logger.warning( - "Episode %d has %d frames but drop_n_first_frames=%d and " - "drop_n_last_frames=%d removes all frames. 
Skipping.", - episode_idx, - ep_length, - drop_n_first_frames, - drop_n_last_frames, - ) - continue - indices.extend(range(start_index + drop_n_first_frames, end_index - drop_n_last_frames)) + from_indices = np.asarray(dataset_from_indices, dtype=np.int64) + to_indices = np.asarray(dataset_to_indices, dtype=np.int64) + if from_indices.shape != to_indices.shape: + raise ValueError( + f"dataset_from_indices and dataset_to_indices must have the same length, " + f"got {len(from_indices)} and {len(to_indices)}" + ) - if not indices: + used = np.ones(len(from_indices), dtype=bool) + if episode_indices_to_use is not None: + used = np.zeros(len(from_indices), dtype=bool) + used[np.asarray(episode_indices_to_use, dtype=np.int64)] = True + + starts = from_indices + drop_n_first_frames + lengths = to_indices - drop_n_last_frames - starts + for episode_idx in np.flatnonzero(used & (lengths <= 0)): + logger.warning( + "Episode %d has %d frames but drop_n_first_frames=%d and " + "drop_n_last_frames=%d removes all frames. Skipping.", + episode_idx, + to_indices[episode_idx] - from_indices[episode_idx], + drop_n_first_frames, + drop_n_last_frames, + ) + used &= lengths > 0 + if not used.any(): raise ValueError( "No valid frames remain after applying drop_n_first_frames and drop_n_last_frames. " "All episodes were either filtered out or had too few frames." 
) - self.indices = indices + self._starts = starts[used] + self._cum_lengths = np.cumsum(lengths[used]) + self._num_frames = int(self._cum_lengths[-1]) self.shuffle = shuffle - self.generator = generator + self.seed = seed + self._epoch = 0 + self._start_index = 0 + + @property + def indices(self) -> list[int]: + """Materialized frame indices in unshuffled order; O(num_frames), introspection only.""" + return [self._frame_index(k) for k in range(self._num_frames)] + + def set_epoch(self, epoch: int) -> None: + self._epoch = epoch + + def state_dict(self) -> dict: + return {"epoch": self._epoch, "start_index": self._start_index} + + def load_state_dict(self, state: dict) -> None: + self._epoch = state["epoch"] + self._start_index = state["start_index"] + + def _epoch_generator(self, epoch: int) -> torch.Generator: + # Derive a per-epoch seed from (seed, epoch) so the permutation is a pure function of both + # and reproduces identically on every rank without touching the global RNG. + epoch_seed = int(np.random.SeedSequence([self.seed, epoch]).generate_state(1, dtype=np.uint64)[0]) + return torch.Generator().manual_seed(epoch_seed) + + def _frame_index(self, position: int) -> int: + episode = int(np.searchsorted(self._cum_lengths, position, side="right")) + position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0) + return int(self._starts[episode]) + position_in_episode def __iter__(self) -> Iterator[int]: + # Advance epoch state eagerly, not on first consumption of the generator. 
+ epoch, start = self._epoch, self._start_index + self._epoch += 1 + self._start_index = 0 + return self._iter_epoch(epoch, start) + + def _iter_epoch(self, epoch: int, start: int) -> Iterator[int]: if self.shuffle: - for i in torch.randperm(len(self.indices), generator=self.generator): - yield self.indices[i] + order = torch.randperm(self._num_frames, generator=self._epoch_generator(epoch)) + for k in range(start, self._num_frames): + yield self._frame_index(int(order[k])) else: - for i in self.indices: - yield i + for k in range(start, self._num_frames): + yield self._frame_index(k) def __len__(self) -> int: - return len(self.indices) + return self._num_frames + + +def compute_sampler_state(step: int, num_frames: int, batch_size: int, num_processes: int) -> dict: + """Map an optimization step to an `EpisodeAwareSampler` state for sample-exact resume. + + Under accelerate's batch sharding, one step consumes `batch_size * num_processes` sampler + positions and each rank sees `ceil(ceil(num_frames / batch_size) / num_processes)` batches + per epoch (`even_batches` padding included). The start index provably stays below + `num_frames`; the `min` is defensive. + + Assumptions (resume is only sample-exact when they hold): + - `num_processes` and `batch_size` match the run that wrote the checkpoint. Both scale how + many positions a step consumes, so the epoch/offset are wrong if either changed. The + caller passes the checkpoint's `num_processes` and `batch_size` and warns on a mismatch. + - accelerate uses `even_batches=True` (its default). The `ceil(... / num_processes)` term + mirrors that padding; with `even_batches=False` the per-epoch batch count differs and + the boundary is off. 
+ """ + batches_per_epoch = math.ceil(math.ceil(num_frames / batch_size) / num_processes) + epoch, batches_into_epoch = divmod(step, batches_per_epoch) + start_index = min(batches_into_epoch * batch_size * num_processes, num_frames) + return {"epoch": epoch, "start_index": start_index} diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a35d4229d..70a5e9e9d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -36,6 +36,8 @@ from tqdm import tqdm from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, + load_training_num_processes, load_training_state, save_checkpoint, update_last_checkpoint, @@ -43,7 +45,7 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import EpisodeAwareSampler, make_dataset +from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -237,18 +239,17 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - # Dataset loading synchronization: each node's local main process downloads first to avoid - # race conditions (the global main process only exists on node 0, so gating on it would let - # all ranks of the other nodes download and build the Arrow cache concurrently). 
- if accelerator.is_local_main_process: - if is_main_process: - logging.info("Creating dataset") + # Dataset loading synchronization: the global main process downloads once to the shared + # dataset root, then a barrier lets every other rank read the already-populated copy. + # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. + if is_main_process: + logging.info("Creating dataset") dataset = make_dataset(cfg) accelerator.wait_for_everyone() - # Now all other processes can safely load the dataset from the local cache - if not accelerator.is_local_main_process: + # Other ranks read from the shared copy populated by the main process. + if not is_main_process: dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. @@ -392,22 +393,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # create dataloader for offline training - if hasattr(active_cfg, "drop_n_last_frames"): + if not cfg.dataset.streaming: + # All non-streaming (map-style) datasets use EpisodeAwareSampler. + # The order is a pure function of (seed, epoch), so every rank independently produces the + # same permutation. accelerate then shards it disjointly across ranks via BatchSamplerShard + # without needing a `generator` attribute to synchronize an RNG, and resume is sample-exact. shuffle = False - # A dedicated generator (rather than the global torch RNG) lets accelerator.prepare - # synchronize the shuffle permutation across ranks, keeping batch shards disjoint even - # when ranks consume the global RNG asymmetrically (e.g. eval on the main process only). 
- sampler_generator = torch.Generator() - if cfg.seed is not None: - sampler_generator.manual_seed(cfg.seed) sampler = EpisodeAwareSampler( dataset.meta.episodes["dataset_from_index"], dataset.meta.episodes["dataset_to_index"], episode_indices_to_use=dataset.episodes, - drop_n_last_frames=active_cfg.drop_n_last_frames, + drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), shuffle=True, - generator=sampler_generator, + seed=cfg.seed if cfg.seed is not None else 0, ) + if cfg.resume and step > 0: + # The resume offset depends on the (num_processes, batch_size) that produced `step`, so + # use the values recorded in the checkpoint (falling back to the current ones for older + # ckpts that did not store them). + saved_num_processes = load_training_num_processes(cfg.checkpoint_path) + saved_batch_size = load_training_batch_size(cfg.checkpoint_path) + ckpt_num_processes = saved_num_processes or accelerator.num_processes + ckpt_batch_size = saved_batch_size or cfg.batch_size + if is_main_process and saved_num_processes not in (None, accelerator.num_processes): + logging.warning( + f"Resuming with num_processes={accelerator.num_processes} but the checkpoint was " + f"written with num_processes={saved_num_processes}. The data order resumes at the " + "right epoch/offset, but per-rank sample-exactness requires the same world size." + ) + if is_main_process and saved_batch_size not in (None, cfg.batch_size): + logging.warning( + f"Resuming with batch_size={cfg.batch_size} but the checkpoint was written with " + f"batch_size={saved_batch_size}. The data order resumes at the right epoch/offset, " + "but per-rank sample-exactness requires the same batch size." 
+ ) + sampler_state = compute_sampler_state(step, len(sampler), ckpt_batch_size, ckpt_num_processes) + sampler.load_state_dict(sampler_state) + if is_main_process: + logging.info( + f"Resuming data order at epoch {sampler_state['epoch']}, " + f"sample {sampler_state['start_index']}" + ) else: shuffle = True sampler = None @@ -544,6 +570,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): scheduler=lr_scheduler, preprocessor=preprocessor, postprocessor=postprocessor, + num_processes=accelerator.num_processes, + batch_size=cfg.batch_size, ) update_last_checkpoint(checkpoint_dir) if wandb_logger: diff --git a/tests/datasets/test_sampler.py b/tests/datasets/test_sampler.py index 95429c7ec..7a5fc0fe0 100644 --- a/tests/datasets/test_sampler.py +++ b/tests/datasets/test_sampler.py @@ -114,28 +114,17 @@ def test_shuffle(): assert set(sampler) == {0, 1, 2, 3, 4, 5} -def test_shuffle_with_generator_is_deterministic(): - # Two samplers shuffling with same-seed generators must yield identical permutations. - # This is what keeps batch shards disjoint across ranks in distributed training, where - # accelerate synchronizes the sampler's generator state instead of the global torch RNG. - sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - assert list(sampler_a) == list(sampler_b) - +def test_shuffle_is_reproducible_across_instances(): + # The order is a pure function of (seed, epoch), so two fresh samplers (e.g. two ranks) + # produce the same permutation without any generator synchronization. + sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) + sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) + epoch_0 = list(sampler_a) + assert list(sampler_b) == epoch_0 # Desyncing the global RNG must not affect the permutation. 
- sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42)) - order_before = list(sampler_c) - sampler_c.generator.manual_seed(42) + sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, seed=42) torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would - assert list(sampler_c) == order_before - - -def test_generator_attribute_defaults_to_none(): - # accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`, - # so the attribute must exist even when no generator is passed. - sampler = EpisodeAwareSampler([0], [6], shuffle=True) - assert sampler.generator is None - assert set(sampler) == {0, 1, 2, 3, 4, 5} + assert list(sampler_c) == epoch_0 def test_negative_drop_first_frames_raises(): @@ -161,3 +150,87 @@ def test_partial_episode_drop_warns(caplog): # Episode 0 is skipped (1 frame, drop 1), Episode 1 keeps frames 2-5 assert sampler.indices == [2, 3, 4, 5] assert "Episode 0" in caplog.text + + +# --- seeded (seed, epoch) shuffling, resume, and state --- + +from lerobot.datasets.sampler import compute_sampler_state # noqa: E402 + +EPISODE_BOUNDS = ([0, 2, 3], [2, 3, 6]) # episodes of 2, 1 and 3 frames + + +@pytest.mark.parametrize("num_frames", [1, 2, 3, 37, 64, 100]) +def test_deterministic_sampler_shuffle_is_permutation(num_frames): + for seed in (0, 1, 1234): + sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=seed) + assert sorted(sampler) == list(range(num_frames)) + + +def test_deterministic_sampler_epochs_reproduce_and_differ(): + sampler_a = EpisodeAwareSampler([0], [100], shuffle=True, seed=42) + sampler_b = EpisodeAwareSampler([0], [100], shuffle=True, seed=42) + epoch_0 = list(sampler_a) + assert list(sampler_b) == epoch_0 # same (seed, epoch) -> same order on any process + epoch_1 = list(sampler_a) # __iter__ auto-advances the epoch + assert epoch_1 != epoch_0 + assert sorted(epoch_1) == sorted(epoch_0) + sampler_a.set_epoch(0) + assert 
list(sampler_a) == epoch_0 + assert list(EpisodeAwareSampler([0], [100], shuffle=True, seed=7)) != epoch_0 + + +def test_deterministic_sampler_resume_mid_epoch(): + reference = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + epoch_0 = list(reference) + epoch_1 = list(reference) + for start in (0, 1, 4, len(epoch_0)): + resumed = EpisodeAwareSampler(*EPISODE_BOUNDS, shuffle=True, seed=42) + resumed.load_state_dict({"epoch": 0, "start_index": start}) + assert list(resumed) == epoch_0[start:] + # the resumed sampler continues into the same epoch 1 as the uninterrupted one + assert list(resumed) == epoch_1 + + +def test_deterministic_sampler_construction_stores_only_boundaries(): + # Construction is O(num_episodes), not O(num_frames): a million-frame single episode + # instantiates from just its boundaries without materializing a per-frame index list. + num_frames = 1_000_000 + sampler = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + assert len(sampler) == num_frames + assert sampler._starts.shape == (1,) and sampler._cum_lengths.shape == (1,) + + +def test_deterministic_sampler_resume_is_exact_at_scale(): + # Seeded randperm makes resume sample-exact at non-trivial sizes: regenerating the epoch's + # permutation and slicing from the saved offset reproduces the remaining order exactly. + num_frames = 100_000 + reference = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + epoch_0 = list(reference) + assert sorted(epoch_0) == list(range(num_frames)) + start = num_frames - 5 + resumed = EpisodeAwareSampler([0], [num_frames], shuffle=True, seed=0) + resumed.load_state_dict({"epoch": 0, "start_index": start}) + assert list(resumed) == epoch_0[start:] + + +def test_compute_sampler_state(): + # 100 frames, batch 10, 2 ranks -> 10 underlying batches, 5 per rank per epoch. 
+ assert compute_sampler_state(step=0, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 0, + "start_index": 0, + } + # step 7 -> epoch 1, 2 per-rank batches in = 2 * 10 * 2 = 40 samples in + assert compute_sampler_state(step=7, num_frames=100, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 40, + } + # uneven epoch: 95 frames -> 10 underlying batches (last short), still 5 per rank + assert compute_sampler_state(step=12, num_frames=95, batch_size=10, num_processes=2) == { + "epoch": 2, + "start_index": 40, + } + # uneven sharding: 105 frames -> 11 underlying batches, 6 per rank (even_batches pads) + assert compute_sampler_state(step=11, num_frames=105, batch_size=10, num_processes=2) == { + "epoch": 1, + "start_index": 100, + } diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 8e5b3f167..c171763c2 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -20,6 +20,8 @@ from unittest.mock import Mock, patch from lerobot.common.train_utils import ( get_step_checkpoint_dir, get_step_identifier, + load_training_batch_size, + load_training_num_processes, load_training_state, load_training_step, save_checkpoint, @@ -63,6 +65,28 @@ def test_load_training_step(tmp_path): assert loaded_step == step +def test_save_training_state_records_num_processes(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, num_processes=4) + assert load_training_num_processes(tmp_path) == 4 + + +def test_load_training_num_processes_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the world size was recorded must still load (back-compat). 
+ save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_num_processes(tmp_path) is None + + +def test_save_training_state_records_batch_size(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler, batch_size=32) + assert load_training_batch_size(tmp_path) == 32 + + +def test_load_training_batch_size_absent_returns_none(tmp_path, optimizer, scheduler): + # Checkpoints written before the batch size was recorded must still load (back-compat). + save_training_state(tmp_path, 10, optimizer, scheduler) + assert load_training_batch_size(tmp_path) is None + + def test_update_last_checkpoint(tmp_path): checkpoint = tmp_path / "0005" checkpoint.mkdir() From 02b315ab6a6709c04a4d19b8cacfc1988c247d37 Mon Sep 17 00:00:00 2001 From: Nikodem Bartnik <39432165+NikodemBartnik@users.noreply.github.com> Date: Fri, 12 Jun 2026 13:26:52 +0200 Subject: [PATCH 11/27] Docs/model card improvements (#3634) * update policy deployment instruction with rollout * add port and fix formatting * add more base models to generate model card * updated and extended model descriptions * fix bug * improved and extended structure * exclude the templates from config * add images and visualize dataset button * add all policies we have docs for * remove policies without the docs * new fields, improved examples --- .pre-commit-config.yaml | 3 + src/lerobot/policies/pretrained.py | 90 ++++++- .../templates/lerobot_modelcard_template.md | 252 ++++++++++++++---- 3 files changed, 282 insertions(+), 63 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dff7416f4..8ae913e4e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,6 +65,9 @@ repos: name: Format Markdown with Prettier types_or: [markdown, mdx] args: [--prose-wrap=preserve] + # Jinja2 model-card templates use a .md extension but contain {% ... %} / + # {{ ... }} tags that prettier's Markdown formatter mangles (e.g. table loops). 
+ exclude: ^src/lerobot/templates/.*\.md$ ##### Security ##### - repo: https://github.com/gitleaks/gitleaks diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 724f920f3..a69487f3f 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -29,6 +29,7 @@ from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn +from lerobot.__version__ import __version__ from lerobot.configs import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.utils.hub import HubMixin @@ -38,6 +39,67 @@ from .utils import log_model_loading_keys T = TypeVar("T", bound="PreTrainedPolicy") +def _build_card_context( + cfg: TrainPipelineConfig | None, + dataset_repo_id: str | None, + input_features: dict | None, + output_features: dict | None, +) -> dict: + """Collect optional data for the model-card template. + + Returns plain values only (no Markdown) — the template in + ``lerobot/templates/lerobot_modelcard_template.md`` decides how and whether to show + each one. Everything is best-effort: anything unavailable is left empty/None and the + template simply skips that section, so this never breaks a Hub push. 
+ """ + context = { + "training": None, + "input_features": input_features or {}, + "output_features": output_features or {}, + "dataset": None, + "robot_type": None, + "cameras": [], + } + + if cfg is not None: + optimizer = getattr(cfg, "optimizer", None) + context["training"] = { + "steps": cfg.steps, + "batch_size": cfg.batch_size, + "seed": cfg.seed, + "optimizer": getattr(optimizer, "type", None) if optimizer else None, + "lr": getattr(optimizer, "lr", None) if optimizer else None, + "lerobot_version": __version__, + } + + if dataset_repo_id: + dataset_cfg = getattr(cfg, "dataset", None) + try: + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata + + meta = LeRobotDatasetMetadata( + dataset_repo_id, + root=getattr(dataset_cfg, "root", None), + revision=getattr(dataset_cfg, "revision", None), + ) + context["dataset"] = { + "repo_id": dataset_repo_id, + "episodes": meta.total_episodes, + "frames": meta.total_frames, + "fps": meta.fps, + "tasks": [str(task) for task in meta.tasks.index], + } + context["robot_type"] = meta.robot_type + context["cameras"] = [key.split(".")[-1] for key in meta.camera_keys] + except Exception as e: # noqa: BLE001 — dataset details are optional, never fail the push + logging.warning( + f"Could not load dataset metadata for '{dataset_repo_id}'; those sections will be " + f"omitted from the model card. 
({e})" + ) + + return context + + class ActionSelectKwargs(TypedDict, total=False): noise: Tensor | None @@ -228,7 +290,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors card = self.generate_model_card( - cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags + cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg ) card.save(str(saved_path / "README.md")) @@ -246,9 +308,20 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): logging.info(f"Model pushed to {commit_info.repo_url.url}") def generate_model_card( - self, dataset_repo_id: str, model_type: str, license: str | None, tags: list[str] | None + self, + dataset_repo_id: str, + model_type: str, + license: str | None, + tags: list[str] | None, + cfg: TrainPipelineConfig | None = None, ) -> ModelCard: - base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model + base_model_mapping = { + "smolvla": "lerobot/smolvla_base", + "pi0": "lerobot/pi0_base", + "pi05": "lerobot/pi05_base", + "pi0_fast": "lerobot/pi0fast-base", + "xvla": "lerobot/xvla-base", + } card_data = ModelCardData( license=license or "apache-2.0", @@ -257,13 +330,20 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): tags=list(set(tags or []).union({"robotics", "lerobot", model_type})), model_name=model_type, datasets=dataset_repo_id, - base_model=base_model, + base_model=base_model_mapping.get(model_type), ) + context = _build_card_context( + cfg, dataset_repo_id, self.config.input_features, self.config.output_features + ) + # Used by the template to pre-fill commands and the "Fine-tuned from" line. 
+ context["policy_repo_id"] = getattr(self.config, "repo_id", None) + context["base_model"] = base_model_mapping.get(model_type) + template_card = ( files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text(encoding="utf-8") ) - card = ModelCard.from_template(card_data, template_str=template_card) + card = ModelCard.from_template(card_data, template_str=template_card, **context) card.validate() return card diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index b93e83b6e..6ecda06c9 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -13,77 +13,213 @@ [SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware. {% elif model_name == "act" %} [Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates. -{% elif model_name == "tdmpc" %} -[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function. {% elif model_name == "diffusion" %} [Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation. -{% elif model_name == "vqbet" %} -[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills. 
{% elif model_name == "pi0" %} -**π₀ (Pi0)** - -π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository. - -**Model Overview** - -π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks. - -For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0). +[π₀ (Pi0)](https://www.physicalintelligence.company/blog/pi0) is a general-purpose robot foundation model from Physical Intelligence: a generalist Vision-Language-Action policy that understands visual inputs, interprets natural language instructions, and controls a variety of different robots across diverse tasks. The LeRobot implementation is adapted from their open-source OpenPI repository. {% elif model_name == "pi05" %} -**π₀.₅ (Pi05) Policy** - -π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository. - -**Model Overview** - -π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a big challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training. - -For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05). 
+[π₀.₅ (Pi05)](https://www.physicalintelligence.company/blog/pi05) is a Vision-Language-Action model from Physical Intelligence designed for open-world generalization: it evolves π₀ to generalize to entirely new environments and situations that were never seen during training. The LeRobot implementation is adapted from their open-source OpenPI repository. +{% elif model_name == "molmoact2" %} +[MolmoAct2](https://allenai.org/blog/molmoact2) is an open robotics foundation model from the Allen Institute for AI (Ai2) that maps camera images and language instructions to robot action chunks. The LeRobot implementation supports training and evaluation of the regular MolmoAct2 model. +{% elif model_name == "vla_jepa" %} +[VLA-JEPA](https://arxiv.org/abs/2602.10098) is a Vision-Language-Action model that combines a Qwen3-VL language backbone with a self-supervised video world model (V-JEPA2) and a flow-matching DiT action head. {% elif model_name == "gaussian_actor" %} This is a Gaussian Actor policy (Gaussian policy with a tanh squash) — the policy-side component used by [Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) and related maximum-entropy continuous-control algorithms. +{% elif model_name == "pi0_fast" %} +[π₀-FAST (Pi0-FAST)](https://www.physicalintelligence.company/research/fast) is a Vision-Language-Action model for general robot control, from Physical Intelligence. It models continuous robot actions with autoregressive next-token prediction using FAST (Frequency-space Action Sequence Tokenization), training up to 5x faster than diffusion-based π₀. +{% elif model_name == "eo1" %} +[EO-1](https://huggingface.co/papers/2508.21112) is a Vision-Language-Action model for general robot control. It pairs a Qwen2.5-VL backbone for vision-language understanding with a continuous flow-matching action head that denoises action chunks. 
+{% elif model_name == "groot" %} +[GR00T N1.5](https://github.com/NVIDIA/Isaac-GR00T) is an open, cross-embodiment foundation model from NVIDIA for generalized humanoid robot reasoning and skills. It takes language and images as input and uses a flow-matching action transformer to predict actions conditioned on vision, language, and proprioception. +{% elif model_name == "multi_task_dit" %} +[Multi-Task Diffusion Transformer (DiT)](https://huggingface.co/papers/2507.05331) extends Diffusion Policy with a large Diffusion Transformer and text + vision conditioning for multi-task robot learning. It supports both diffusion and flow-matching objectives and reaches high dexterity with only ~450M parameters. +{% elif model_name == "wall_x" %} +[WALL-OSS](https://huggingface.co/papers/2509.11766) is an open-source foundation model for embodied intelligence from XSquare Robot. Built on Qwen2.5-VL, it uses a tightly-coupled multimodal architecture with flow matching to unify semantic reasoning and high-frequency action generation for cross-embodiment control. +{% elif model_name == "xvla" %} +[X-VLA](https://huggingface.co/papers/2510.10274) is a soft-prompted, flow-matching Vision-Language-Action framework that treats each robot or hardware setup as a "task" encoded with a small set of learnable Soft Prompt embeddings, letting a single model reconcile diverse robot morphologies, sensors, and action spaces. {% else %} -_Model type not recognized — please update this template._ +This is a **{{ model_name }}** policy trained with [LeRobot](https://github.com/huggingface/lerobot). 
+{% endif %} +{% set diagrams = { + "smolvla": "https://cdn-uploads.huggingface.co/production/uploads/640e21ef3c82bd463ee5a76d/aooU0a3DMtYmy_1IWMaIM.png", + "pi0": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pi0%20(1).png", + "pi0_fast": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-pifast.png", + "eo1": "https://huggingface.co/datasets/HaomingSong/lerobot-documentation-images/resolve/main/lerobot/eo_pipeline.png", + "groot": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/lerobot-groot-paper1%20(1).png", + "wall_x": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/walloss-lerobot-paper.png", + "xvla": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png" +} %} +{% if diagrams.get(model_name) %} +

+ {{ model_name }} architecture +

{% endif %} + + This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot). -See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index). - ---- - -## How to Get Started with the Model - -For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy). -Below is the short version on how to train and run inference/eval: - -### Train from scratch - -```bash -lerobot-train \ - --dataset.repo_id=${HF_USER}/ \ - --policy.type=act \ - --output_dir=outputs/train/ \ - --job_name=lerobot_training \ - --policy.device=cuda \ - --policy.repo_id=${HF_USER}/ - --wandb.enable=true -``` - -_Writes checkpoints to `outputs/train//checkpoints/`._ - -### Evaluate the policy/run inference - -```bash -lerobot-record \ - --robot.type=so100_follower \ - --dataset.repo_id=/eval_ \ - --policy.path=/ \ - --episodes=10 -``` - -Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint. +{% set policy_docs = { + "act": "act", + "smolvla": "smolvla", + "pi0": "pi0", + "pi0_fast": "pi0fast", + "pi05": "pi05", + "molmoact2": "molmoact2", + "vla_jepa": "vla_jepa", + "eo1": "eo1", + "groot": "groot", + "xvla": "xvla", + "multi_task_dit": "multi_task_dit", + "wall_x": "walloss" +} %} +{% if policy_docs.get(model_name) %}Learn how to train and run it in the [LeRobot {{ model_name }} guide](https://huggingface.co/docs/lerobot/main/en/{{ policy_docs[model_name] }}), or browse the [full documentation](https://huggingface.co/docs/lerobot/index). +{% else %}See the [full LeRobot documentation](https://huggingface.co/docs/lerobot/index). 
+{% endif %} --- ## Model Details - **License:** {{ license | default("\[More Information Needed]", true) }} +{% if base_model %}- **Fine-tuned from:** [{{ base_model }}](https://huggingface.co/{{ base_model }}) +{% endif %}{% if robot_type %}- **Robot type:** `{{ robot_type }}` +{% endif %}{% if cameras %}- **Cameras:** {% for camera in cameras %}`{{ camera }}`{% if not loop.last %}, {% endif %}{% endfor %} +{% endif %} +{% if input_features or output_features %} +## Inputs & Outputs + +The policy consumes these observation features and produces these action features. +{% if input_features %} +**Inputs** + +| Feature | Type | Shape | +| --- | --- | --- | +{% for name, feature in input_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` | +{% endfor %}{% endif %}{% if output_features %} +**Outputs** + +| Feature | Type | Shape | +| --- | --- | --- | +{% for name, feature in output_features.items() %}| `{{ name }}` | {{ feature.type.value }} | `{{ feature.shape }}` | +{% endfor %}{% endif %}{% endif %} +{% if dataset %} +## Training Dataset + +- **Repository:** [{{ dataset.repo_id }}](https://huggingface.co/datasets/{{ dataset.repo_id }}) +- **Episodes:** {{ dataset.episodes }} +- **Frames:** {{ dataset.frames }} +- **Frame rate:** {{ dataset.fps }} FPS +{% if dataset.tasks %}- **Task(s):** {% for task in dataset.tasks %}"{{ task }}"{% if not loop.last %}, {% endif %}{% endfor %} +{% endif %} + + + + +{% endif %} +{% if training %} +## Training Configuration + +| Setting | Value | +| --- | --- | +| Training steps | {{ training.steps }} | +| Batch size | {{ training.batch_size }} | +{% if training.optimizer %}| Optimizer | {{ training.optimizer }} | +{% endif %}{% if training.lr %}| Learning rate | {{ training.lr }} | +{% endif %}{% if training.seed is not none %}| Seed | {{ training.seed }} | +{% endif %}| LeRobot version | {{ training.lerobot_version }} | +{% endif %} +--- + +## How to Get Started with the Model + +New to LeRobot? 
These guides cover the full workflow: + +- **[Install LeRobot](https://huggingface.co/docs/lerobot/main/en/installation)** — set up the `lerobot` package. +- **[Hardware setup](https://huggingface.co/docs/lerobot/main/en/hardware_guide)** — assemble, wire, and calibrate your robot and cameras. +- **[Record data & train a policy](https://huggingface.co/docs/lerobot/en/il_robots)** — the end-to-end imitation-learning walkthrough. +- **[CLI cheat-sheet](https://huggingface.co/docs/lerobot/main/en/cheat-sheet)** — quick reference for the `lerobot-*` commands. + +The short version to run and train this policy: + +### Run the policy on your robot + +```bash +lerobot-rollout \ + --strategy.type=base \ + --robot.type={{ robot_type | default("", true) }} \ + --robot.port= \ + --robot.cameras="{ : {type: opencv, index_or_path: , width: 640, height: 480, fps: 30}, : {type: opencv, index_or_path: , width: 640, height: 480, fps: 30}}" \ + --policy.path={{ policy_repo_id | default("/", true) }} \ + --task="{% if dataset and dataset.tasks %}{{ dataset.tasks[0] }}{% else %}{% endif %}" \ + --duration=60 +``` + +Replace the remaining `<...>` placeholders with your own values: `--robot.port` and the camera names/indices are specific to your machine, and the camera names must match the observation keys this policy was trained on. + +When `--strategy.type=base` is used, the script does not record the episodes. Omitting `--duration` makes the policy run indefinitely. For more information, see the [rollout documentation](https://huggingface.co/docs/lerobot/main/en/inference). 
+ +{% if base_model %}### Train your own policy + +This policy type is usually fine-tuned from the pretrained base model [{{ base_model }}](https://huggingface.co/{{ base_model }}): + +```bash +lerobot-train \ + --dataset.repo_id=${HF_USER}/ \ + --policy.path={{ base_model }} \ + --output_dir=outputs/train/ \ + --job_name=lerobot_training \ + --policy.device=cuda \ + --policy.repo_id=${HF_USER}/ \ + --wandb.enable=true +``` +{% else %}### Train your own policy + +```bash +lerobot-train \ + --dataset.repo_id=${HF_USER}/ \ + --policy.type={{ model_name }} \ + --output_dir=outputs/train/ \ + --job_name=lerobot_training \ + --policy.device=cuda \ + --policy.repo_id=${HF_USER}/ \ + --wandb.enable=true +``` +{% endif %} +_Writes checkpoints to `outputs/train//checkpoints/`._ + +--- + +## Evaluation + + + +_No evaluation results have been provided for this policy yet._ + +--- + +## Citation + +If you use this policy, please cite the method linked in the description above, along with LeRobot: + +```bibtex +@misc{cadene2024lerobot, + author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, + title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch}, + howpublished = "\url{https://github.com/huggingface/lerobot}", + year = {2024} +} +``` From cec8ee0be6e1006f9ecbd829fea59d051a11c515 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:12:33 +0200 Subject: [PATCH 12/27] feat: language annotation pipeline (#3471) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Steerable annotation pipeline (lerobot-annotate) that populates the language_persistent and language_events columns introduced in 
PR 1 (#3467) directly into data/chunk-*/file-*.parquet. This is PR 2 of the three-PR plan: PR 1 (Add extensive language support #3467): schema + DSL + rendering, base of this PR PR 2 (this PR): annotation pipeline writing into PR 1's columns PR 3: model with language prediction and runtime A VLM (Qwen-VL family, served on vLLM) watches each episode's video and emits grounded language annotations: subtasks, plans, memory, task rephrasings, interjections + speech, and per-camera VQA. The pipeline is built for production annotation at scale — single-camera grounding, embedded-frame inputs, a describe-then-segment grounding flow, and a deterministic full-episode coverage guarantee — informed by Scale's dense-captioning findings (representation > sampling, rules > reasoning, model capacity is the biggest lever, two-pass systems compound errors) --- Makefile | 6 + docs/source/_toctree.yml | 2 + docs/source/annotation_pipeline.mdx | 291 +++++++ examples/annotations/run_hf_job.py | 77 ++ pyproject.toml | 18 +- src/lerobot/annotations/__init__.py | 15 + .../steerable_pipeline/__init__.py | 36 + .../annotations/steerable_pipeline/config.py | 211 +++++ .../steerable_pipeline/executor.py | 253 ++++++ .../annotations/steerable_pipeline/frames.py | 481 +++++++++++ .../steerable_pipeline/modules/__init__.py | 25 + .../steerable_pipeline/modules/general_vqa.py | 248 ++++++ .../modules/interjections_and_speech.py | 211 +++++ .../modules/plan_subtasks_memory.py | 780 ++++++++++++++++++ .../steerable_pipeline/prompts/__init__.py | 33 + .../prompts/interjections_initial_speech.txt | 12 + .../prompts/interjections_interjection.txt | 46 ++ .../prompts/plan_memory.txt | 36 + .../prompts/plan_subtask_describe.txt | 27 + .../prompts/plan_subtasks.txt | 112 +++ .../prompts/plan_task_aug_axes.txt | 67 ++ .../prompts/plan_task_rephrasings.txt | 32 + .../prompts/plan_video_task.txt | 17 + .../steerable_pipeline/prompts/vqa.txt | 32 + .../annotations/steerable_pipeline/reader.py | 216 +++++ 
.../annotations/steerable_pipeline/staging.py | 92 +++ .../steerable_pipeline/validator.py | 332 ++++++++ .../steerable_pipeline/vlm_client.py | 617 ++++++++++++++ .../annotations/steerable_pipeline/writer.py | 341 ++++++++ src/lerobot/scripts/lerobot_annotate.py | 206 +++++ tests/annotations/__init__.py | 0 tests/annotations/_helpers.py | 58 ++ tests/annotations/conftest.py | 58 ++ tests/annotations/run_e2e_smoke.py | 116 +++ tests/annotations/test_frames.py | 246 ++++++ tests/annotations/test_modules.py | 390 +++++++++ .../test_pipeline_recipe_render.py | 183 ++++ tests/annotations/test_validator.py | 133 +++ tests/annotations/test_vlm_client.py | 41 + tests/annotations/test_writer.py | 357 ++++++++ tests/fixtures/dataset_factories.py | 61 ++ tests/scripts/test_lerobot_annotate.py | 86 ++ uv.lock | 124 ++- 43 files changed, 6723 insertions(+), 2 deletions(-) create mode 100644 docs/source/annotation_pipeline.mdx create mode 100644 examples/annotations/run_hf_job.py create mode 100644 src/lerobot/annotations/__init__.py create mode 100644 src/lerobot/annotations/steerable_pipeline/__init__.py create mode 100644 src/lerobot/annotations/steerable_pipeline/config.py create mode 100644 src/lerobot/annotations/steerable_pipeline/executor.py create mode 100644 src/lerobot/annotations/steerable_pipeline/frames.py create mode 100644 src/lerobot/annotations/steerable_pipeline/modules/__init__.py create mode 100644 src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py create mode 100644 src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py create mode 100644 src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/__init__.py create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/interjections_initial_speech.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/interjections_interjection.txt create mode 100644 
src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/plan_task_aug_axes.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/plan_task_rephrasings.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/plan_video_task.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/prompts/vqa.txt create mode 100644 src/lerobot/annotations/steerable_pipeline/reader.py create mode 100644 src/lerobot/annotations/steerable_pipeline/staging.py create mode 100644 src/lerobot/annotations/steerable_pipeline/validator.py create mode 100644 src/lerobot/annotations/steerable_pipeline/vlm_client.py create mode 100644 src/lerobot/annotations/steerable_pipeline/writer.py create mode 100644 src/lerobot/scripts/lerobot_annotate.py create mode 100644 tests/annotations/__init__.py create mode 100644 tests/annotations/_helpers.py create mode 100644 tests/annotations/conftest.py create mode 100644 tests/annotations/run_e2e_smoke.py create mode 100644 tests/annotations/test_frames.py create mode 100644 tests/annotations/test_modules.py create mode 100644 tests/annotations/test_pipeline_recipe_render.py create mode 100644 tests/annotations/test_validator.py create mode 100644 tests/annotations/test_vlm_client.py create mode 100644 tests/annotations/test_writer.py create mode 100644 tests/scripts/test_lerobot_annotate.py diff --git a/Makefile b/Makefile index e02f02403..d3987101f 100644 --- a/Makefile +++ b/Makefile @@ -178,3 +178,9 @@ test-smolvla-ete-eval: --env.episode_length=5 \ --eval.n_episodes=1 \ --eval.batch_size=1 + +# E2E annotation pipeline smoke test against a tiny in-memory fixture +# dataset. 
Opt-in (not part of `make test-end-to-end`) and uses a stub VLM +# backend, so it does not require a real model checkpoint or GPU. +annotation-e2e: + uv run python -m tests.annotations.run_e2e_smoke diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0d4e36172..5d847a94d 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -45,6 +45,8 @@ title: Language Columns and Recipes - local: tools title: Tools + - local: annotation_pipeline + title: Annotation Pipeline - local: video_encoding_parameters title: Video encoding parameters - local: streaming_video_encoding diff --git a/docs/source/annotation_pipeline.mdx b/docs/source/annotation_pipeline.mdx new file mode 100644 index 000000000..02658ec9a --- /dev/null +++ b/docs/source/annotation_pipeline.mdx @@ -0,0 +1,291 @@ +# Annotation Pipeline + +`lerobot-annotate` watches each episode's video with a vision-language +model (VLM) and writes natural-language annotations back into your +dataset. It fills the two language columns from the +[Language Columns and Recipes](./language_and_recipes) page — +`language_persistent` and `language_events` — straight into +`data/chunk-*/file-*.parquet`. + +In short: point it at a LeRobot dataset, and it adds subtasks, plans, +memory, interjections, speech, and visual Q&A that a policy can be +trained on. 
+ +## How it fits together + +```text + your dataset lerobot-annotate + (LeRobot v3.1) + │ + ▼ + ┌─────────────────────────────────────────────────────┐ + │ read episodes │ + └──────────────────────────┬──────────────────────────┘ + │ + ┌────────────────────┼────────────────────┐ + ▼ ▼ ▼ + ┌──────────┐ ┌───────────────┐ ┌──────────┐ one shared Qwen-VL + │ plan │ │ interjections │ │ vqa │ ◀── server (vLLM, OpenAI + └────┬─────┘ └───────┬───────┘ └────┬─────┘ API) drives all three + └────────────────────┼─────────────────────┘ + │ each module stages raw JSONL + ▼ into .annotate_staging/ + ┌─────────────────┐ + │ validator │ ◀── checks everything + └────────┬────────┘ + ▼ + ┌─────────────────┐ + │ writer │ + └────────┬────────┘ + ▼ + data/chunk-*/file-*.parquet + (+ meta/info.json tools) +``` + +Three modules (`plan`, `interjections`, `vqa`) all talk to **one** shared +VLM. Each module stages its output to disk, a validator checks it, and a +single writer rewrites the dataset shards in place. + +## What the pipeline produces + +Each module emits a few kinds of annotation ("styles"), routed to one of +the two language columns: + +| Style / atom | Column | Module | +| ------------------------------------------- | --------------------- | --------------- | +| `subtask` (Pi0.7-style "how, not what") | `language_persistent` | `plan` | +| `plan` (initial + refresh on interjection) | `language_persistent` | `plan` | +| `memory` (MEM-style compression) | `language_persistent` | `plan` | +| `task_aug` (rephrasings of the task) | `language_persistent` | `plan` | +| `interjection` | `language_events` | `interjections` | +| speech tool-call atom (`style=null`, `say`) | `language_events` | `interjections` | +| `vqa` (user / assistant pair) | `language_events` | `vqa` | + +### How subtasks are generated + +The `plan` module doesn't ask the VLM for subtasks in one shot. Instead +it uses a two-step **describe → segment** flow: + +1. 
**Describe** — the VLM narrates only what it actually sees in the + chosen camera (no guessing about the task). +2. **Segment** — that description is fed back in, and the VLM splits the + episode into consecutive atomic subtasks. + +Both passes see the episode as **timestamped contact sheets** — frames +sampled at `frames_per_second` (`2.0` by default, i.e. one frame every 0.5s) and packed into JPEG +grids with each frame's time burned into its corner, so the VLM cites +exact boundary times directly. This is far cheaper in vision tokens than +one image per frame, so the sampling can stay dense; episodes whose sampled frames exceed +`max_frames_per_prompt` are split into windows at the same density and +merged. Both prompts also carry a causal **event-boundary** definition (a +new event starts when an object becomes held / is released / reaches a new +location / a lid changes state / contents move) to sharpen where cuts land. + +The resulting spans are then stitched into a gap-free, full-episode +cover, so **every frame has exactly one active subtask**. See +[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py) +for the production settings (single camera, timestamped contact sheets, +auto-windowed subtask generation). + +### Tools + +The writer does **not** add a `tools` column to the parquet. The tool +catalog lives in `meta/info.json["tools"]` instead (see [Tools](./tools)). +After every run, the pipeline makes sure the canonical `say` schema is in +that list, keeping any tools you declared beforehand. + +Want to add your own tool? Edit `meta/info.json["tools"]` directly — the +pipeline preserves whatever is already there. That makes the tool visible +to the chat template, so the model can learn to _generate_ the call. The +runtime layer that actually _executes_ a generated call (the `Tool` +protocol / `TOOL_REGISTRY` under `src/lerobot/tools/`) is not part of +this PR — the [Tools](./tools) doc marks those pieces as +not-yet-implemented. 
+ +## Running on Hugging Face Jobs + +Annotation runs on [Hugging Face Jobs](https://huggingface.co/docs/hub/en/jobs). +The repo ships a launcher script you copy and tweak for your dataset: + +```bash +HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py +``` + +[`run_hf_job.py`](https://github.com/huggingface/lerobot/blob/main/examples/annotations/run_hf_job.py) +starts a single-GPU `h200` job (bump it to `h200x4` for big datasets) +that: + +1. installs `lerobot` (from `main`) plus the annotation extras, +2. boots one vLLM server per GPU (using the `vllm/vllm-openai` image) and + drives it over the OpenAI-compatible API, +3. runs the `plan` / `interjections` / `vqa` modules across the dataset + with `lerobot-annotate`, +4. with `--push_to_hub=true`, uploads the result to `--new_repo_id` (or + back to `--repo_id` in place if you leave that unset). + +To use a different dataset, model, or hub repo, edit the `CMD` block in +the script. Every flag there maps directly to a `lerobot-annotate` flag +(run `lerobot-annotate --help` for the full list). + +## Key options + +These are the flags you'll reach for most often. Run +`lerobot-annotate --help` for everything else; the defaults are tuned for +short manipulation episodes. + +### Dataset in / out + +| Flag | Default | What it does | +| ----------------- | ------- | ----------------------------------------------------------------------- | +| `--repo_id` | — | Hub dataset to annotate (downloaded if `--root` unset). | +| `--root` | — | Annotate a local dataset directory instead. | +| `--new_repo_id` | — | Push the result to a new repo (leaves the source repo untouched). | +| `--push_to_hub` | `false` | Upload after annotating (to `--new_repo_id`, else back to `--repo_id`). | +| `--only_episodes` | all | Annotate just these episode indices (handy for a test run). | +| `--seed` | `1729` | Seeds the RNGs that pick interjection timestamps + VQA question types. 
| + +### Which modules run + +Every module is on by default and can be toggled independently (set to +`false` to skip it, e.g. to iterate on one module at a time): + +| Flag | Default | Turns off | +| ------------------------- | ------- | ----------------------------------- | +| `--plan.enabled` | `true` | subtasks + plan + memory + task_aug | +| `--interjections.enabled` | `true` | interjections + speech atoms | +| `--vqa.enabled` | `true` | the VQA pairs | + +### The VLM (`--vlm.*`) + +| Flag | Default | What it does | +| -------------------------- | ------------------ | ----------------------------------------------------------------------------------- | +| `--vlm.model_id` | `Qwen/Qwen3.6-27B` | The model to serve and prompt. | +| `--vlm.camera_key` | first `images.*` | Which camera every prompt is grounded on. | +| `--vlm.serve_command` | auto | The exact `vllm serve …` command (set TP size, GPU memory, `--max-model-len` here). | +| `--vlm.parallel_servers` | `1` | Independent servers for round-robin routing (one per GPU). | +| `--vlm.num_gpus` | `0` | GPUs per server (`0` = one each). | +| `--vlm.client_concurrency` | `16` | In-flight requests across all servers. | +| `--vlm.max_new_tokens` | `512` | Generation cap per call. | +| `--vlm.temperature` | `0.2` | Sampling temperature. | + +### Subtasks / plan / memory (`--plan.*`) + +| Flag | Default | What it does | +| ------------------------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------- | +| `--plan.frames_per_second` | `2.0` | Frame sampling rate for the contact sheets (`2.0` = one frame every 0.5s). | +| `--plan.max_frames_per_prompt` | `60` | Frame budget per VLM call. Episodes whose sampling exceeds this are auto-windowed at the same density, then stitched. | +| `--plan.contact_sheet_columns` | `5` | Columns per contact-sheet grid (`contact_sheet_frames_per_sheet` tiles, time row-major). 
| +| `--plan.plan_max_steps` | `8` | Upper bound on subtasks per episode. | +| `--plan.subtask_describe_first` | `true` | Run the describe→segment grounding pass (best subtask quality; +1 call/episode). | +| `--plan.emit_plan` | `true` | Emit the numbered `plan` rows (`false` = subtasks + memory only). | +| `--plan.emit_memory` | `true` | Emit the `memory` rows (`false` = subtasks + plan only); symmetric to `emit_plan`. | +| `--plan.n_task_rephrasings` | `10` | How many `task_aug` rephrasings to emit (`0` disables). | +| `--plan.derive_task_from_video` | `if_short` | Use the dataset task as-is (`off`), only when it's missing/short (`if_short`), or always re-derive from video (`always`). | + +### Interjections + VQA + +| Flag | Default | What it does | +| ----------------------------------------------- | ------- | ---------------------------------------------------------- | +| `--interjections.max_interjections_per_episode` | `3` | Cap on interjection/speech pairs per episode. | +| `--vqa.vqa_emission_hz` | `1.0` | How often VQA pairs are emitted. | +| `--vqa.restrict_to_default_camera` | `false` | Ground VQA only on `--vlm.camera_key` (else every camera). | +| `--executor.episode_parallelism` | `16` | Episodes processed concurrently within each phase. | + +## Contributing new modules + +The pipeline is built to grow, and **contributions are very welcome** — +a brand-new module (say, trajectory traces or affordances), a new prompt +template, a smarter grounding flow, or quality fixes to the existing +`plan` / `interjections` / `vqa` modules. + +Every module lives under +`src/lerobot/annotations/steerable_pipeline/modules/`, shares the VLM +client and the keyframe cache, writes its raw output to the staging +tree, and plugs into the executor as its own phase. Got an idea? Open an +issue or PR on [the repo](https://github.com/huggingface/lerobot). 
+ +## How recipes consume the output + +The annotations are meant to be read by recipes (see +[Language Columns and Recipes](./language_and_recipes)). Typically: + +- low-level / high-level / memory-update branches read + `subtask` / `plan` / `memory` from `language_persistent`. +- an interjection-response branch reads `interjection` events plus the + paired speech atom (merged into one assistant turn via `tool_calls_from`) + and the matching `plan` refresh at the same timestamp. +- a VQA branch reads the `(vqa, user)` and `(vqa, assistant)` pairs from + `language_events`. + +## Why state and events are split + +Two ideas shape the design: + +1. **Persistent state vs. exact events.** Persistent rows (`subtask`, + `plan`, `memory`) apply to the whole episode and answer "what's true + right now?". Event rows (`interjection`, `vqa`, speech) appear only on + the one frame whose timestamp matches. Timestamps are copied straight + from the source parquet — never recomputed in floating point. +2. **One VLM pass.** All three modules share a single VLM client (the + OpenAI-compatible client talking to the job's vLLM server), so you pay + for one model load per dataset, not three. + +## Re-running a single module + +Each module stages its raw output to +`<root>/.annotate_staging/episode_{N:06d}/<module>.jsonl`. This makes +prompt iteration cheap: re-running one module overwrites only its own +JSONL, then the writer recomposes the final parquet. Disable modules you +don't want with `--plan.enabled=false` (and likewise +`--interjections.enabled` / `--vqa.enabled`) to test one at a time. 
+ +## What the validator checks + +Before the writer runs, `StagingValidator` confirms: + +- every event row lands exactly on a real frame timestamp; +- no speech / interjection pairs are left orphaned; +- `plan` is refreshed at every interjection timestamp; +- `memory` rows fall on subtask boundaries (a warning, not an error); +- each VQA assistant `content` is valid JSON in one of the + bbox / keypoint / count / attribute / spatial shapes; +- every row goes to the column chosen by `column_for_style(style)`. + +Any error aborts the writer. Pass `--skip_validation=true` to override +while debugging. + +## Where each module's ideas come from + +- **`plan` — subtasks.** Hi Robot ([Shi 2025](https://arxiv.org/abs/2502.19417)) + for atom granularity ("pick up one piece of lettuce", "place bowl to + box"); Pi0.7 ([Physical Intelligence 2025](https://pi.website/pi07)) + for "how, not what" detail. +- **`plan` — memory.** MEM ([Torne 2026](https://arxiv.org/abs/2603.03596)): + keep only the minimal relevant information — preserve outcomes, drop + specific attributes. +- **`interjections`.** Hi Robot's scenario taxonomy: negative task, + situated correction, specific constraint, preference. Speech is a + tool-call-only atom + (`tool_calls=[{type:function, function:{name:"say", arguments:{text:...}}}]`). +- **`vqa`.** ECoT ([Zawalski 2024](https://arxiv.org/abs/2407.08693)) for + grounded features (pixel bounding boxes `[x_min, y_min, x_max, y_max]`, + keypoints) and Steerable VLA Policies + ([Zhao 2025](https://arxiv.org/abs/2509.07626)) for multi-abstraction + grounding. Pi0.7 also grounds answers across abstraction levels. + +When improving a module, tweak its prompt template in +`src/lerobot/annotations/steerable_pipeline/prompts/` rather than +rewriting from scratch. + +## Roughly how much it costs + +Per episode, the pipeline makes about `plan_max_steps` plan calls, +`max_interjections_per_episode` interjection calls, and +`vqa_emission_hz × episode_seconds` VQA calls. 
With the defaults (8 +subtasks, 3 interjections, 1 Hz × 3 pairs) on a 30-second episode, that's +~50 VLM calls. + +Storage stays small: `language_persistent` is at most tens of KB per +episode (parquet dictionary-encodes the one entry that repeats across +frames), and `language_events` is empty on most frames — its size scales +with the number of emissions, not `num_frames × num_emissions`. diff --git a/examples/annotations/run_hf_job.py b/examples/annotations/run_hf_job.py new file mode 100644 index 000000000..a77e22f14 --- /dev/null +++ b/examples/annotations/run_hf_job.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Launch ``lerobot-annotate`` on a Hugging Face job (vllm + Qwen3.6-27B VLM). + +Spawns one single-GPU ``h200`` job that: + + 1. installs ``lerobot`` from ``main`` plus the annotation extras, + 2. boots one vllm server with Qwen3.6-27B (dense VLM), + 3. runs the plan / interjections / vqa modules across the dataset + in free-form mode (each episode generates its own subtasks + + memory), + 4. uploads the annotated dataset to ``--new_repo_id`` (when set) + or back to ``--repo_id``. + +Usage: + + HF_TOKEN=hf_... uv run python examples/annotations/run_hf_job.py + +Adjust ``CMD`` (dataset, model, hub repo) and ``flavor`` below for your +run. 
For larger datasets, scale to ``h200x4`` and raise +``--vlm.parallel_servers`` / ``--vlm.num_gpus`` to match. +""" + +import os + +from huggingface_hub import get_token, run_job + +token = os.environ.get("HF_TOKEN") or get_token() +if not token: + raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`") + +CMD = ( + "apt-get update -qq && apt-get install -y -qq git ffmpeg && " + "pip install --no-deps " + "'lerobot @ git+https://github.com/huggingface/lerobot.git@main' && " + "pip install --upgrade-strategy only-if-needed " + "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect " + "openai && " + "export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && " + "export VLLM_VIDEO_BACKEND=pyav && " + "lerobot-annotate " + "--repo_id=pepijn223/robocasa_pretrain_human300_v4 " + "--new_repo_id=pepijn223/robocasa_pretrain_human300_v4_annotated " + "--push_to_hub=true " + "--vlm.backend=openai " + "--vlm.model_id=Qwen/Qwen3.6-27B " + "--vlm.num_gpus=1 " + '--vlm.serve_command="vllm serve Qwen/Qwen3.6-27B ' + "--tensor-parallel-size 1 --max-model-len 32768 " + '--gpu-memory-utilization 0.8 --uvicorn-log-level warning --port {port}" ' + "--vlm.serve_ready_timeout_s=1800 " + # Qwen3.6 ships with thinking on; annotation wants plain JSON answers. + "--vlm.chat_template_kwargs='{\"enable_thinking\": false}'" +) + +job = run_job( + image="vllm/vllm-openai:latest", + command=["bash", "-c", CMD], + flavor="h200", + secrets={"HF_TOKEN": token}, + timeout="2h", +) +print(f"Job URL: {job.url}") +print(f"Job ID: {job.id}") diff --git a/pyproject.toml b/pyproject.toml index e43f8ef81..0dc86d7ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -229,6 +229,21 @@ vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen async = ["lerobot[grpcio-dep]", "lerobot[matplotlib-dep]"] peft = ["lerobot[transformers-dep]", "lerobot[peft-dep]"] +# Annotation pipeline (lerobot-annotate). 
The only backend is ``openai``, +# which talks to any OpenAI-compatible server (``vllm serve`` / +# ``transformers serve`` / hosted). Distributed runs use Hugging Face Jobs +# (see examples/annotations/run_hf_job.py). +annotations = [ + "lerobot[dataset]", + "lerobot[transformers-dep]", + "openai>=1.40,<2.0", + # ``vllm`` is intentionally NOT a hard dep: it pins an older torch, and + # uv's single unified lock would then cap ``torch`` for every extra + # (e.g. forcing 2.8 while ``torchcodec`` in [dataset] needs 2.11 -> ABI + # break in CI). The HF Jobs image (``vllm/vllm-openai``) provides vLLM; + # install it locally only if you run your own ``vllm serve``. +] + # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools>=1.73.1,<2.0.0", "mypy>=1.19.1", "ruff>=0.14.1", "lerobot[notebook]"] notebook = ["jupyter>=1.0.0,<2.0.0", "ipykernel>=6.0.0,<7.0.0"] @@ -323,6 +338,7 @@ lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" +lerobot-annotate="lerobot.scripts.lerobot_annotate:main" lerobot-rollout="lerobot.scripts.lerobot_rollout:main" # ---------------- Tool Configurations ---------------- @@ -341,7 +357,7 @@ torch = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] torchvision = [{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" }] [tool.setuptools.package-data] -lerobot = ["envs/*.json"] +lerobot = ["envs/*.json", "annotations/steerable_pipeline/prompts/*.txt"] [tool.setuptools.packages.find] where = ["src"] diff --git a/src/lerobot/annotations/__init__.py b/src/lerobot/annotations/__init__.py new file mode 100644 index 000000000..67782f192 --- /dev/null +++ b/src/lerobot/annotations/__init__.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace 
Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/lerobot/annotations/steerable_pipeline/__init__.py b/src/lerobot/annotations/steerable_pipeline/__init__.py new file mode 100644 index 000000000..a8da5e05e --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/__init__.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Steerable annotation pipeline producing ``language_persistent`` and +``language_events`` columns for LeRobot datasets. 
+ +The pipeline is decomposed into three independently runnable modules whose +outputs are staged per-episode before a final parquet rewrite: + +- :mod:`.modules.plan_subtasks_memory` (the ``plan`` module) — persistent styles +- :mod:`.modules.interjections_and_speech` (the ``interjections`` module) — event styles + speech +- :mod:`.modules.general_vqa` (the ``vqa`` module) — event-style VQA pairs +""" + +from .config import AnnotationPipelineConfig +from .validator import StagingValidator, ValidationReport +from .writer import LanguageColumnsWriter + +__all__ = [ + "AnnotationPipelineConfig", + "LanguageColumnsWriter", + "StagingValidator", + "ValidationReport", +] diff --git a/src/lerobot/annotations/steerable_pipeline/config.py b/src/lerobot/annotations/steerable_pipeline/config.py new file mode 100644 index 000000000..86d6cadd9 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/config.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +@dataclass +class PlanConfig: + """``plan`` module: subtasks + plan + memory + task augmentation.""" + + enabled: bool = True + + # ``task_aug`` rephrasings at t=0 (renderer rotates ${task} among them); 0 disables. 
+ n_task_rephrasings: int = 10 + + # Derive the task from video instead of episode_task: off / if_short / always. + # Affects prompts only; ``meta/tasks.parquet`` is untouched. + derive_task_from_video: str = "if_short" + derive_task_min_words: int = 3 + + # --- Frame input: timestamped contact sheets (always on) --------------- + # The subtask describe/segment passes ALWAYS render the episode as + # macrodata/refiner-style contact sheets: sampled frames packed into JPEG + # grids with each frame's timestamp burned into its corner, so the VLM + # cites the exact source time of a boundary directly. This is far cheaper + # in vision tokens than one image per frame (≈2× faster subtask generation + # in practice), which is why the sampling is dense by default. + # + # ``frames_per_second`` is the sampling rate: 2.0 = one frame every 0.5s. + frames_per_second: float = 2.0 + # Frame budget per VLM call (= columns × rows × sheets). When a whole + # episode sampled at ``frames_per_second`` exceeds this, the episode is + # AUTOMATICALLY split into consecutive windows of + # ``max_frames_per_prompt`` frames each (one describe→segment call per + # window, still at the full ``frames_per_second`` density), and the + # per-window spans are merged + stitched into one contiguous cover. So an + # episode of any length is always covered at the full sampling density. + max_frames_per_prompt: int = 60 + contact_sheet_columns: int = 5 + contact_sheet_frames_per_sheet: int = 20 + contact_sheet_frame_width: int = 224 + contact_sheet_quality: int = 84 + + min_subtask_seconds: float = 1.5 + plan_max_steps: int = 8 + + # Narrate-only grounding pass before segmenting — best defense against subtasks + # invented from the task text (+1 VLM call/episode). + subtask_describe_first: bool = True + + # Emit ``style="plan"`` rows at each boundary; False = subtasks + memory only. + emit_plan: bool = True + + # Emit ``style="memory"`` rows at each boundary; False = subtasks (+ plan) only. 
+ # Symmetric counterpart of ``emit_plan``. + emit_memory: bool = True + + # (subtask spans are always stitched to a contiguous full-episode cover; not configurable.) + + # Optional EgoMimic-style 5-axis task augmentation; replaces n_task_rephrasings. + task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig()) + + +@dataclass +class TaskAugAxesConfig: + """5-axis t=0 task augmentation (EgoMimic-style): synonym / omit_arm / + omit_orientation / omit_grasp_method / combined. Replaces n_task_rephrasings + when enabled; each variant becomes a ``task_aug`` row. Axes with nothing to + omit emit fewer entries. Defaults (3+3+2+2+2) match EgoMimic.""" + + enabled: bool = False + + synonym_paraphrase: int = 3 + omit_arm: int = 3 + omit_orientation: int = 2 + omit_grasp_method: int = 2 + combined_omissions: int = 2 + + +@dataclass +class InterjectionsConfig: + """``interjections`` module: interjections + paired speech.""" + + enabled: bool = True + + # Each emits a paired (interjection, speech) row + a plan refresh at that ts. + max_interjections_per_episode: int = 3 + interjection_min_t: float = 2.0 + + # Frame window centered on the timestamp so the VLM sees motion, not one frame. + interjection_window_seconds: float = 2.0 + interjection_window_frames: int = 4 + + +@dataclass +class VqaConfig: + """``vqa`` module: general VQA.""" + + enabled: bool = True + vqa_emission_hz: float = 1.0 + K: int = 1 + """Consecutive frames per emission tick. The VLM grounds on the FIRST frame, + so K>1 smears stale labels onto moved frames. Default 1 (no smear).""" + question_types: tuple[str, ...] = ("bbox", "keypoint", "count", "attribute", "spatial") + + # True: ground VQA only on --vlm.camera_key (default: every camera). + restrict_to_default_camera: bool = False + + +@dataclass +class VlmConfig: + """Shared Qwen-VL client configuration.""" + + # Only ``openai`` (OpenAI-compatible vLLM server, auto-spawned when + # auto_serve=True); ``stub`` is for tests. 
+ backend: str = "openai" + model_id: str = "Qwen/Qwen3.6-27B" + + # OpenAI-compatible endpoint; ``EMPTY`` key works for local servers. + api_base: str = "http://localhost:8000/v1" + api_key: str = "EMPTY" + + # Spawn a server if none answers api_base; False = fail fast on a remote. + auto_serve: bool = True + serve_port: int = 8000 + # Override the auto-serve command; ``{port}`` substituted per replica. + serve_command: str | None = None + + # Independent servers for round-robin routing (one per GPU). num_gpus=0 = one each. + parallel_servers: int = 1 + num_gpus: int = 0 + client_concurrency: int = 16 + serve_ready_timeout_s: float = 600.0 + + max_new_tokens: int = 512 + temperature: float = 0.2 + + # Auto-serve context length (None → 32768); other vLLM flags go in serve_command. + max_model_len: int | None = None + + # Camera for keyframes; None → first ``observation.images.*`` key. + camera_key: str | None = None + # Forwarded as extra_body.chat_template_kwargs (e.g. {"enable_thinking": false}). + chat_template_kwargs: dict[str, Any] | None = None + + +@dataclass +class ExecutorConfig: + """Executor settings (intra-process episode concurrency; distribution via HF Jobs).""" + + # Episodes processed concurrently per phase; main knob for saturating the servers. + episode_parallelism: int = 16 + + +@dataclass +class AnnotationPipelineConfig: + """Top-level config for ``lerobot-annotate`` (rewrites data shards in place).""" + + # Hub dataset: download source when ``root`` unset; push target when push_to_hub + # is on and ``new_repo_id`` unset. + repo_id: str | None = None + + # Separate push target (matches the LeRobot edit tools). Unset → push in place. + new_repo_id: str | None = None + + root: Path | None = None + + # Defaults to ``/.annotate_staging/``. 
+ staging_dir: Path | None = None + + seed: int = 1729 + + plan: PlanConfig = field(default_factory=PlanConfig) + interjections: InterjectionsConfig = field(default_factory=InterjectionsConfig) + vqa: VqaConfig = field(default_factory=VqaConfig) + + vlm: VlmConfig = field(default_factory=VlmConfig) + executor: ExecutorConfig = field(default_factory=ExecutorConfig) + + skip_validation: bool = False + only_episodes: tuple[int, ...] | None = None + + # Keyframe decode backend forwarded to ``decode_video_frames``. None → + # library default (torchcodec when available, else PyAV). Or pin + # ``"torchcodec"`` / ``"pyav"`` explicitly. + video_backend: str | None = None + + # Upload to the Hub (new_repo_id if set, else repo_id; one must be set). + push_to_hub: bool = False + push_private: bool = False + push_commit_message: str | None = None + + def resolved_staging_dir(self, root: Path) -> Path: + return self.staging_dir if self.staging_dir is not None else root / ".annotate_staging" diff --git a/src/lerobot/annotations/steerable_pipeline/executor.py b/src/lerobot/annotations/steerable_pipeline/executor.py new file mode 100644 index 000000000..69d10bc89 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/executor.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""In-process executor that runs the annotation phases. 
+ +The executor runs **six phases** in dependency order: + + phase 1: ``plan`` module (plan + subtasks + memory) + phase 2: ``interjections`` module (interjections + speech) + phase 3: ``plan`` plan-update pass — re-runs plan emission at every + interjection timestamp produced by phase 2 + phase 4: ``vqa`` module (VQA) + phase 5: validator + phase 6: writer + +Phase 3 is why the ``plan`` module must be re-entered after the +``interjections`` module — to refresh ``plan`` rows at interjection +timestamps. + +Distributed execution is provided by Hugging Face Jobs (see +``examples/annotations/run_hf_job.py``); the runner inside the job +invokes ``lerobot-annotate`` which uses this in-process executor. +Episode-level concurrency is controlled by +``ExecutorConfig.episode_parallelism``. +""" + +from __future__ import annotations + +import logging +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .config import AnnotationPipelineConfig +from .reader import EpisodeRecord, iter_episodes +from .staging import EpisodeStaging +from .validator import StagingValidator +from .writer import LanguageColumnsWriter + +logger = logging.getLogger(__name__) + + +@dataclass +class PhaseResult: + """Summary of one pipeline phase across all episodes.""" + + name: str + episodes_processed: int + episodes_skipped: int + + +@dataclass +class PipelineRunSummary: + """Aggregated result returned by :meth:`Executor.run`.""" + + phases: list[PhaseResult] + written_paths: list[Path] + validation_report: Any # ValidationReport, kept Any to avoid import cycle + + +@dataclass +class Executor: + """Run all six phases over a dataset root in-process. + + Episode-level concurrency comes from ``ExecutorConfig.episode_parallelism`` + (a thread pool); cluster-level concurrency comes from running this + executor inside a Hugging Face Job. 
Tests construct the executor + directly with stub modules. + """ + + config: AnnotationPipelineConfig + plan: Any # PlanSubtasksMemoryModule + interjections: Any # InterjectionsAndSpeechModule + vqa: Any # GeneralVqaModule + writer: LanguageColumnsWriter + validator: StagingValidator + + def run(self, root: Path) -> PipelineRunSummary: + records = list(iter_episodes(root, only_episodes=self.config.only_episodes)) + n = len(records) + if n == 0: + raise ValueError(f"No episodes found under {root}/data/") + + print(f"[annotate] {n} episodes total", flush=True) + + staging_dir = self.config.resolved_staging_dir(root) + staging_dir.mkdir(parents=True, exist_ok=True) + + phases: list[PhaseResult] = [] + + # Phase 1: ``plan`` module (plan + subtasks + memory) + phases.append(self._run_module_phase("plan", records, staging_dir, self.plan)) + # Phase 2: ``interjections`` module (interjections + speech). It + # reads the ``plan`` module's subtask rows from the same staging + # tree to ground the interjection prompt in the correct local subtask. + phases.append(self._run_module_phase("interjections", records, staging_dir, self.interjections)) + # Phase 3: ``plan`` plan-update pass at interjection timestamps. 
+ phases.append(self._run_plan_update_phase(records, staging_dir)) + # Phase 4: ``vqa`` module (VQA) + phases.append(self._run_module_phase("vqa", records, staging_dir, self.vqa)) + + print("[annotate] running validator...", flush=True) + report = self.validator.validate(records, staging_dir) + if not report.ok and not self.config.skip_validation: + raise RuntimeError(f"Staging validation failed: {report.summary()}") + print(f"[annotate] validator: {report.summary()}", flush=True) + + print(f"[annotate] writing parquet shards into {root}/data/...", flush=True) + written = self.writer.write_all(records, staging_dir, root) + print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True) + + # Keep meta/info.json aligned with the parquet schema we just wrote. + # Idempotent and additive: existing user metadata is preserved. + self._ensure_annotation_metadata_in_info(root) + + return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report) + + @staticmethod + def _ensure_annotation_metadata_in_info(root: Path) -> None: + """Write language features and canonical tools to ``meta/info.json``. + + ``LanguageColumnsWriter`` adds ``language_persistent`` and + ``language_events`` to parquet shards. The metadata must advertise + those columns too, otherwise non-streaming ``LeRobotDataset`` loads + cast against the old schema and fail on the extra parquet columns. 
+ """ + from lerobot.datasets.io_utils import load_info, write_info # noqa: PLC0415 + from lerobot.datasets.language import SAY_TOOL_SCHEMA, language_feature_info # noqa: PLC0415 + + info_path = root / "meta" / "info.json" + if not info_path.exists(): + return + try: + info = load_info(root) + except Exception as exc: # noqa: BLE001 + print(f"[annotate] could not read {info_path}: {exc}", flush=True) + return + + changed = False + + merged_features = {**info.features, **language_feature_info()} + if merged_features != info.features: + info.features = merged_features + changed = True + + existing = info.tools or [] + names = {(t.get("function") or {}).get("name") for t in existing if isinstance(t, dict)} + if SAY_TOOL_SCHEMA["function"]["name"] not in names: + info.tools = [*existing, SAY_TOOL_SCHEMA] + changed = True + + if changed: + write_info(info, root) + print( + "[annotate] meta/info.json: " + f"language_features={list(language_feature_info())}, " + f"tools={[t['function']['name'] for t in (info.tools or [])]}", + flush=True, + ) + + def _run_module_phase( + self, + name: str, + records: list[EpisodeRecord], + staging_dir: Path, + module: Any, + ) -> PhaseResult: + if not module.enabled: + print(f"[annotate] phase={name} skipped (module disabled)", flush=True) + return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records)) + n = len(records) + parallelism = max(1, min(self.config.executor.episode_parallelism, n)) + print( + f"[annotate] phase={name} starting on {n} episode(s) (parallelism={parallelism})", + flush=True, + ) + t0 = time.time() + + def _do(idx_record: tuple[int, EpisodeRecord]) -> tuple[int, int, float]: + i, record = idx_record + ep_start = time.time() + staging = EpisodeStaging(staging_dir, record.episode_index) + module.run_episode(record, staging) + return i, record.episode_index, time.time() - ep_start + + processed = 0 + if parallelism == 1: + for i, record in enumerate(records, 1): + _, ep_idx, elapsed = _do((i, 
record)) + processed += 1 + print( + f"[annotate] {name} episode {i}/{n} (idx={ep_idx}) done in {elapsed:.1f}s", + flush=True, + ) + else: + with ThreadPoolExecutor(max_workers=parallelism) as pool: + futures = [pool.submit(_do, (i, r)) for i, r in enumerate(records, 1)] + for fut in as_completed(futures): + i, ep_idx, elapsed = fut.result() + processed += 1 + print( + f"[annotate] {name} episode {processed}/{n} " + f"(idx={ep_idx}, submit_order={i}) done in {elapsed:.1f}s", + flush=True, + ) + total = time.time() - t0 + print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True) + return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0) + + def _run_plan_update_phase( # noqa: PLR0915 + self, records: list[EpisodeRecord], staging_dir: Path + ) -> PhaseResult: + """Re-emit ``plan`` rows at each timestamp the ``interjections`` module produced. + + The ``plan`` module owns the prompt; the ``interjections`` module + produced the timestamps. This phase therefore calls back into the + ``plan`` module with the interjection timestamps so its existing + prompt path is reused. + """ + if not self.plan.enabled or not self.interjections.enabled: + return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records)) + processed = 0 + for record in records: + staging = EpisodeStaging(staging_dir, record.episode_index) + interjection_rows = [ + row for row in staging.read("interjections") if row.get("style") == "interjection" + ] + interjection_times = [float(row["timestamp"]) for row in interjection_rows] + interjection_texts = [str(row.get("content") or "") for row in interjection_rows] + if interjection_times: + self.plan.run_plan_updates(record, staging, interjection_times, interjection_texts) + processed += 1 + # Episodes without any interjections are skipped (no plan refresh + # needed); count them so the summary's processed+skipped == total. 
+ return PhaseResult( + name="plan_update", + episodes_processed=processed, + episodes_skipped=len(records) - processed, + ) diff --git a/src/lerobot/annotations/steerable_pipeline/frames.py b/src/lerobot/annotations/steerable_pipeline/frames.py new file mode 100644 index 000000000..a6c904673 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/frames.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Keyframe extraction for the annotation pipeline. + +Modules attach decoded camera frames to their VLM prompts so the model can +ground subtask decomposition, interjection scenarios, and VQA in actual +visual content. The pipeline shares one provider across modules and one +episode at a time, with a small per-episode cache so multiple modules +querying the same timestamp pay decode cost once. 
+""" + +from __future__ import annotations + +import io +import logging +import math +import threading +from collections.abc import Sequence +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Protocol + +import PIL.Image +import torch + +from lerobot.configs.video import VideoEncoderConfig +from lerobot.datasets.video_utils import decode_video_frames, reencode_video + +from .reader import EpisodeRecord, snap_to_frame + +logger = logging.getLogger(__name__) + + +class FrameProvider(Protocol): + """Decodes camera frames at episode-relative timestamps.""" + + @property + def camera_keys(self) -> list[str]: + """All ``observation.images.*`` feature keys this provider can decode.""" + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + """Return one decoded frame per timestamp from ``camera_key`` (or default). + + Frames are ``torch.Tensor`` (``C, H, W`` uint8) — the shape + :func:`lerobot.datasets.video_utils.decode_video_frames` returns. + :func:`to_image_blocks` converts them to PIL only at the VLM-message + boundary. + + Empty list if the camera is unavailable. ``camera_key=None`` falls back + to the provider's default camera so existing single-camera callers + (the ``plan`` and ``interjections`` modules) keep working unchanged. + """ + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + """Return up to ``max_frames`` decoded frames covering the whole episode. + + Sampling is uniform across the episode duration. Frames are + ``torch.Tensor`` (``C, H, W`` uint8); :func:`to_video_block` wraps + them into one ``{"type":"video", "video":}`` block for a + Qwen-VL-compatible model that pools temporally itself. Empty list if + no camera available. 
+ """ + + +@dataclass +class _NullProvider: + """No-op provider used when the dataset has no video keys or in tests.""" + + @property + def camera_keys(self) -> list[str]: + return [] + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + return [] + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + return [] + + +def null_provider() -> FrameProvider: + return _NullProvider() + + +@dataclass +class VideoFrameProvider: + """Decodes frames from the dataset's ``observation.images.*`` streams. + + By default the *first* camera key is used for the ``plan`` module + (subtask decomposition) and the ``interjections`` module (interjection + scenarios) — those prompts care about *what is happening*, not which + angle. The ``vqa`` module instead iterates over every camera in + :attr:`camera_keys` so each frame's + grounded answer (bbox/keypoint/...) is tagged with the camera it was + grounded against. + + ``camera_key`` overrides the default-camera choice but does not restrict + :attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` / + ``video_for_episode`` to read a non-default stream. + + Caches up to ``cache_size`` decoded frames per process to keep + co-timestamped ``interjections`` + ``plan`` plan-update calls cheap. + """ + + root: Path + camera_key: str | None = None + tolerance_s: float = 1e-2 + cache_size: int = 256 + # Keyframe decode backend forwarded to + # :func:`lerobot.datasets.video_utils.decode_video_frames`. ``None`` + # uses the library default (torchcodec when available, else PyAV). 
+ video_backend: str | None = None + _meta: Any = field(default=None, init=False, repr=False) + _cache: dict = field(default_factory=dict, init=False, repr=False) + _camera_keys: list[str] = field(default_factory=list, init=False, repr=False) + # Pipeline runs the three module phases under a ThreadPoolExecutor (see + # ``ExecutorConfig.episode_parallelism``); guard the dict cache and the + # one-shot warn flag against concurrent updates from worker threads. + _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) + # Serializes decode_video_frames calls: torchcodec hands out one + # ``VideoDecoder`` per file from a process-wide cache, and the decoder + # is not safe to drive from multiple threads at once. + _decode_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) + _warned_decode_fail: bool = field(default=False, init=False, repr=False) + + def __post_init__(self) -> None: + from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415 + + self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root) + # Only ``video_keys`` are decodable here: the clip/decode paths read + # ``videos/{camera_key}/from_timestamp`` from episode metadata, which exists + # only for video-stored cameras. Image-stored cameras (also in + # ``camera_keys``) would KeyError, so restrict the list — and the + # default — to video keys. + keys = list(self._meta.video_keys) + # Last-resort fallback: if metadata didn't surface any video keys but + # the caller explicitly named a camera (``--vlm.camera_key=...``), + # trust them — the key is by definition known to exist on the dataset. 
+ if not keys and self.camera_key: + keys = [self.camera_key] + self._camera_keys = keys + if self.camera_key is None: + self.camera_key = keys[0] if keys else None + + @property + def camera_keys(self) -> list[str]: + """All ``observation.images.*`` keys available on this dataset.""" + return list(self._camera_keys) + + def frames_at( + self, + record: EpisodeRecord, + timestamps: list[float], + camera_key: str | None = None, + ) -> list[Any]: + target = camera_key if camera_key is not None else self.camera_key + if not timestamps or target is None: + return [] + # Snap each request to the nearest real frame timestamp: callers + # sample uniform grids whose points land mid-frame, and + # ``decode_video_frames`` rejects queries farther than + # ``tolerance_s`` from a decodable frame. Snapping also dedupes + # repeat queries through the cache. + if record.frame_timestamps: + timestamps = [snap_to_frame(float(ts), record.frame_timestamps) for ts in timestamps] + + out: list[Any] = [] + misses: list[float] = [] + miss_indices: list[int] = [] + with self._lock: + for i, ts in enumerate(timestamps): + key = (record.episode_index, target, round(float(ts), 6)) + cached = self._cache.get(key) + if cached is not None: + out.append(cached) + else: + out.append(None) + misses.append(float(ts)) + miss_indices.append(i) + + if misses: + decoded = self._decode(record.episode_index, misses, target) + # ``_decode`` returns exactly one frame per requested timestamp, + # or an empty list if decoding failed wholesale. A partial list + # would mean a frame/timestamp misalignment, so only pair them up + # when the counts match (``strict=True`` then guards regressions). 
+ if len(decoded) == len(miss_indices): + with self._lock: + for i, frame in zip(miss_indices, decoded, strict=True): + out[i] = frame + key = (record.episode_index, target, round(float(timestamps[i]), 6)) + if len(self._cache) >= self.cache_size: + self._cache.pop(next(iter(self._cache))) + self._cache[key] = frame + # filter out any None left over from decode failures + return [frame for frame in out if frame is not None] + + def video_for_episode( + self, + record: EpisodeRecord, + max_frames: int, + camera_key: str | None = None, + ) -> list[Any]: + """Return up to ``max_frames`` frames uniformly sampled across the episode. + + The whole episode duration is covered; the model picks subtask + boundaries from the temporal pooling it does internally. Frames are + ``torch.Tensor`` (see :meth:`frames_at`). + """ + target = camera_key if camera_key is not None else self.camera_key + if max_frames <= 0 or target is None or not record.frame_timestamps: + return [] + n_frames = min(max_frames, len(record.frame_timestamps)) + if n_frames == len(record.frame_timestamps): + timestamps = list(record.frame_timestamps) + else: + t0 = record.frame_timestamps[0] + t_last = record.frame_timestamps[-1] + if t_last <= t0: + timestamps = [float(t0)] * n_frames + else: + step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0 + timestamps = [float(t0 + i * step) for i in range(n_frames)] + return self.frames_at(record, timestamps, camera_key=target) + + def episode_clip_path(self, record: EpisodeRecord, cache_dir: Path) -> Path | None: + """Extract the episode's subclip to ``cache_dir/ep_{idx:06d}.mp4``. + + Returns ``None`` if the dataset has no video tracks or extraction + failed. Skips re-extract when the cached clip already exists. 
+ Re-encodes to H.264 via + :func:`lerobot.datasets.video_utils.reencode_video` so the resulting + mp4 is decodable by every downstream video processor — stream-copy + would inherit the source codec (often AV1 in modern LeRobot + datasets), which vllm's libav build cannot decode. + """ + if self.camera_key is None: + return None + cache_dir.mkdir(parents=True, exist_ok=True) + out_path = cache_dir / f"ep_{record.episode_index:06d}.mp4" + if out_path.exists() and out_path.stat().st_size > 0: + return out_path + ep = self._meta.episodes[record.episode_index] + from_timestamp = float(ep[f"videos/{self.camera_key}/from_timestamp"]) + to_timestamp = float(ep[f"videos/{self.camera_key}/to_timestamp"]) + src = self.root / self._meta.get_video_file_path(record.episode_index, self.camera_key) + encoder = VideoEncoderConfig(vcodec="h264", pix_fmt="yuv420p", g=None, crf=23, preset="ultrafast") + try: + reencode_video( + src, + out_path, + camera_encoder=encoder, + overwrite=True, + start_time_s=from_timestamp, + end_time_s=to_timestamp, + ) + except Exception: + logger.warning( + "clip extraction failed for episode %s (%s)", record.episode_index, src, exc_info=True + ) + return None + return out_path if out_path.exists() and out_path.stat().st_size > 0 else None + + def _decode(self, episode_index: int, timestamps: list[float], camera_key: str) -> list[Any]: + """Decode ``timestamps`` from the episode's video as ``(C, H, W)`` tensors. + + Delegates to :func:`lerobot.datasets.video_utils.decode_video_frames` + (torchcodec when available, PyAV otherwise; ``video_backend`` pins + one explicitly). Returns one frame per requested timestamp, or ``[]`` + if decoding failed — callers treat ``[]`` as "no frames available". 
+ """ + ep = self._meta.episodes[episode_index] + from_timestamp = ep[f"videos/{camera_key}/from_timestamp"] + shifted = [from_timestamp + ts for ts in timestamps] + video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key) + + try: + # The module phases decode under a ThreadPoolExecutor (see + # ``ExecutorConfig.episode_parallelism``) but torchcodec's cached + # per-file decoder is single-threaded, so serialize decodes on a + # dedicated lock. Frame extraction is a small fraction of episode + # wall time (VLM calls dominate), so the contention is cheap. + with self._decode_lock: + # Stacked ``(N, C, H, W)`` uint8 tensor; one row per timestamp. + decoded = decode_video_frames( + video_path, shifted, self.tolerance_s, backend=self.video_backend, return_uint8=True + ) + return list(decoded) + except Exception as exc: + # Log loudly the first time so a silent vqa-module no-op (every + # prompt skipped because frames_at returned []) is debuggable from + # the job log instead of post-hoc parquet inspection. Subsequent + # failures stay quiet. + with self._lock: + already_warned = self._warned_decode_fail + if not already_warned: + self._warned_decode_fail = True + if not already_warned: + logger.warning( + "VideoFrameProvider._decode failed for episode=%s camera=%s video_path=%s backend=%s: %s", + episode_index, + camera_key, + video_path, + self.video_backend, + exc, + exc_info=exc, + ) + return [] + + +def make_frame_provider( + root: Path, camera_key: str | None = None, video_backend: str | None = None +) -> FrameProvider: + """Build a :class:`VideoFrameProvider` if videos are present, else null.""" + try: + provider = VideoFrameProvider(root=root, camera_key=camera_key, video_backend=video_backend) + except Exception: + return null_provider() + if provider.camera_key is None: + return null_provider() + return provider + + +def _frame_to_pil(frame: Any) -> Any: + """Materialise a decoded frame as a ``PIL.Image`` for the VLM message. 
+ + Frames flow through the provider as ``torch.Tensor`` (``C, H, W`` uint8, + straight from :func:`decode_video_frames`); PIL is only created here, at + the VLM-message boundary, because the chat backends expect PIL images / + data URLs. Non-tensor inputs (e.g. test stubs) pass through untouched. + """ + if not isinstance(frame, torch.Tensor): + return frame + array = frame.detach().cpu() + if array.ndim == 3 and array.shape[0] in (1, 3): + array = array.permute(1, 2, 0) # (C, H, W) -> (H, W, C) + if array.shape[-1] == 1: + array = array.squeeze(-1) + return PIL.Image.fromarray(array.to(torch.uint8).numpy()) + + +def to_image_blocks(frames: list[Any]) -> list[dict[str, Any]]: + """Convert decoded frames to Qwen-VL-compatible image content blocks.""" + return [{"type": "image", "image": _frame_to_pil(frame)} for frame in frames] + + +def to_video_block(frames: list[Any]) -> list[dict[str, Any]]: + """Wrap a list of decoded frames as one Qwen-VL video block. + + Returns ``[]`` when the list is empty, so the caller can splat the result + into a content array without a separate emptiness check. + """ + if not frames: + return [] + return [{"type": "video", "video": [_frame_to_pil(frame) for frame in frames]}] + + +def to_video_url_block(url: str | None, fps: float = 2.0) -> list[dict[str, Any]]: + """Wrap a video file URL as one ``video_url`` block. + + Used by the ``openai`` backend (transformers serve / vllm serve / + ktransformers serve), where the server handles frame sampling. + Returns ``[]`` when ``url`` is ``None`` so the caller can splat. + """ + if not url: + return [] + return [{"type": "video_url", "video_url": {"url": url}, "fps": fps}] + + +def _draw_timestamp_badge(image: PIL.Image.Image, timestamp: float) -> PIL.Image.Image: + """Burn ``timestamp`` (seconds) into the top-left corner of ``image``. + + A solid black badge with white text, so a VLM reading a contact sheet can + cite the exact source time of each tile (e.g. 
``012.50s``) directly, + instead of the caller having to map tile position back to time. Mirrors + the macrodata/refiner contact-sheet convention. + """ + from PIL import ImageDraw, ImageFont + + result = image.copy() + draw = ImageDraw.Draw(result) + font = ImageFont.load_default() + label = f"{timestamp:06.2f}s" + left, top, right, bottom = draw.textbbox((0, 0), label, font=font) + text_w, text_h = right - left, bottom - top + pad = max(3, round(min(image.width, image.height) * 0.018)) + draw.rectangle((0, 0, text_w + pad * 2, text_h + pad * 2), fill=(0, 0, 0)) + draw.text((pad - left, pad - top), label, fill=(255, 255, 255), font=font) + return result + + +def to_contact_sheet_blocks( + frames: Sequence[Any], + timestamps: Sequence[float], + *, + columns: int = 5, + frames_per_sheet: int = 20, + frame_width: int = 224, + quality: int = 84, +) -> list[dict[str, Any]]: + """Pack decoded frames into timestamped JPEG contact-sheet image blocks. + + Each frame is resized to ``frame_width`` wide, stamped with its + episode-relative timestamp, and tiled row-major into grids of + ``frames_per_sheet`` (``columns`` wide). One ``{"type":"image", ...}`` + block is returned per grid; many frames collapse into a few images, so a + long episode's temporal coverage stays dense at a fraction of the vision + tokens N separate frames would cost. ``frames`` and ``timestamps`` must be + aligned and equal length. Returns ``[]`` for empty input. 
+ """ + from PIL import Image + + if not frames: + return [] + columns = max(1, columns) + frames_per_sheet = max(1, frames_per_sheet) + rows_per_sheet = math.ceil(frames_per_sheet / columns) + + tiles: list[PIL.Image.Image] = [] + for ts, frame in zip(timestamps, frames, strict=False): + img = _frame_to_pil(frame) + if not isinstance(img, PIL.Image.Image): + continue + img = img.convert("RGB") + if img.width != frame_width: + height = max(1, round(img.height * frame_width / img.width)) + img = img.resize((frame_width, height), resample=Image.Resampling.BILINEAR) + tiles.append(_draw_timestamp_badge(img, float(ts))) + if not tiles: + return [] + + blocks: list[dict[str, Any]] = [] + for start in range(0, len(tiles), frames_per_sheet): + chunk = tiles[start : start + frames_per_sheet] + cell_w = max(tile.width for tile in chunk) + cell_h = max(tile.height for tile in chunk) + sheet = Image.new("RGB", (cell_w * columns, cell_h * rows_per_sheet), color=(0, 0, 0)) + for i, tile in enumerate(chunk): + x = (i % columns) * cell_w + y = (i // columns) * cell_h + sheet.paste(tile, (x, y)) + # JPEG round-trip at ``quality`` to match the refiner convention and + # shrink the wire payload; vision-token count is set by resolution, so + # the real saving is the grid packing, not the codec. + buf = io.BytesIO() + sheet.save(buf, format="JPEG", quality=quality) + buf.seek(0) + blocks.append({"type": "image", "image": Image.open(buf).convert("RGB")}) + return blocks diff --git a/src/lerobot/annotations/steerable_pipeline/modules/__init__.py b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py new file mode 100644 index 000000000..e9ff8ed23 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/__init__.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .general_vqa import GeneralVqaModule +from .interjections_and_speech import InterjectionsAndSpeechModule +from .plan_subtasks_memory import PlanSubtasksMemoryModule + +__all__ = [ + "GeneralVqaModule", + "InterjectionsAndSpeechModule", + "PlanSubtasksMemoryModule", +] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py new file mode 100644 index 000000000..cdc87b579 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/general_vqa.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``vqa`` module: general VQA at a timed cadence. + +Every ``1/hz`` seconds an emission tick fires; each tick anchors ``K`` +consecutive frames, and every anchored frame gets its own VQA pair. Each +pair is grounded on that single anchor frame — there is no per-pair frame +window. 
For datasets with multiple cameras, every anchored frame produces +one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is +generated against that camera's frame and stamped with the matching +``camera`` field on the emitted rows. The resolver disambiguates via +``camera=...``; recipes that consume VQA do so through one sub-recipe +per camera (see ``recipes/pi05_hirobot.yaml``). + +Within a single (frame, camera) we still emit at most one ``(vqa, user)`` +and one ``(vqa, assistant)`` row, so the resolver contract stays scalar. + +Question types covered (per the plan's ``vqa`` table): bbox, keypoint, +count, attribute, spatial. The assistant's ``content`` is a JSON string +whose schema depends on the question type. Malformed JSON triggers one +retry inside :meth:`VlmClient.generate_json`. +""" + +from __future__ import annotations + +import json +import logging +import random +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import VqaConfig +from ..frames import FrameProvider, null_provider, to_image_blocks +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord +from ..staging import EpisodeStaging +from ..validator import classify_vqa_answer +from ..vlm_client import VlmClient + + +def _emission_anchor_indices(frame_timestamps: Sequence[float], hz: float, k: int) -> list[int]: + """Return the relative frame indices to anchor VQA emissions to. + + For each emission tick (every ``1/hz`` seconds), we anchor ``k`` + consecutive frames starting at the tick. Ticks fall on the nearest + available source frame timestamp. 
+ """ + if hz <= 0 or k <= 0 or not frame_timestamps: + return [] + t0 = frame_timestamps[0] + t_last = frame_timestamps[-1] + period = 1.0 / hz + indices: list[int] = [] + t = t0 + while t <= t_last + 1e-9: + # find the index of the nearest frame to t + nearest_i = min(range(len(frame_timestamps)), key=lambda i: abs(frame_timestamps[i] - t)) + for offset in range(k): + j = nearest_i + offset + if j >= len(frame_timestamps): + break + if not indices or indices[-1] != j: + indices.append(j) + t += period + # dedupe while preserving order + seen: set[int] = set() + deduped: list[int] = [] + for i in indices: + if i in seen: + continue + seen.add(i) + deduped.append(i) + return deduped + + +@dataclass +class GeneralVqaModule: + """Emit grounded VQA pairs at a timed cadence.""" + + vlm: VlmClient + config: VqaConfig + seed: int = 1729 + frame_provider: FrameProvider = field(default_factory=null_provider) + _warned_no_camera: bool = field(default=False, init=False, repr=False) + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + if not record.frame_timestamps: + staging.write("vqa", []) + return + rng = random.Random(f"{self.seed}:{record.episode_index}:vqa") + anchor_idx = _emission_anchor_indices( + record.frame_timestamps, self.config.vqa_emission_hz, self.config.K + ) + cameras = self._target_cameras() + if not cameras: + # No camera available — emit nothing rather than producing + # untagged rows that would fail validation. Surface a loud one- + # time warning so this is never silently a no-op. + if not self._warned_no_camera: + logging.getLogger(__name__).warning( + "vqa module found no cameras on the frame provider — " + "every episode will emit zero VQA rows. Check that the " + "dataset declares observation.images.* features in " + "meta/info.json; passing --vlm.camera_key= at the " + "CLI now also seeds the cameras list as a fallback." 
+ ) + self._warned_no_camera = True + staging.write("vqa", []) + return + + # Build all messages first (one per (frame, camera)), then issue them + # as a single batched generate_json call so the client can fan them + # out concurrently. + per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = [] + for idx in anchor_idx: + ts = float(record.frame_timestamps[idx]) + qtype = rng.choice(self.config.question_types) + for camera in cameras: + messages = self._build_messages(record, qtype, ts, camera) + # Skip cameras that decoded to zero frames at this ts: no point + # asking the VLM to ground a bbox without an image. + if not _has_image_block(messages): + continue + per_call.append((ts, camera, qtype, messages)) + + if not per_call: + staging.write("vqa", []) + return + + results = self.vlm.generate_json([m for _, _, _, m in per_call]) + + rows: list[dict[str, Any]] = [] + for (ts, camera, _qtype, _messages), result in zip(per_call, results, strict=True): + qa = self._postprocess(result) + if qa is None: + continue + question, answer = qa + rows.append( + { + "role": "user", + "content": question, + "style": "vqa", + "timestamp": ts, + "camera": camera, + "tool_calls": None, + } + ) + rows.append( + { + "role": "assistant", + "content": json.dumps(answer, sort_keys=True), + "style": "vqa", + "timestamp": ts, + "camera": camera, + "tool_calls": None, + } + ) + staging.write("vqa", rows) + + def _target_cameras(self) -> list[str]: + """Return the cameras the ``vqa`` module should iterate per anchored frame. + + Defaults to every camera the provider exposes. Datasets with no + cameras (or test/null providers) yield an empty list, which makes + ``run_episode`` a no-op. + + When ``config.restrict_to_default_camera`` is set, VQA grounds on + only the provider's default camera (the single ``--vlm.camera_key`` + stream), matching the plan / interjection modules so the whole + pipeline focuses on one view. 
+ """ + all_cameras = list(getattr(self.frame_provider, "camera_keys", []) or []) + if getattr(self.config, "restrict_to_default_camera", False): + default = getattr(self.frame_provider, "camera_key", None) + if default and default in all_cameras: + return [default] + # ``restrict_to_default_camera`` is set but the configured default + # isn't one the provider exposes. Returning it anyway would make + # ``_decode`` raise a KeyError deep in frame extraction, so warn and + # fall through to every available camera instead. + if default: + logging.getLogger(__name__).warning( + "restrict_to_default_camera is set but camera_key=%r is not in the " + "provider's cameras %s; grounding VQA on all available cameras instead.", + default, + all_cameras, + ) + return all_cameras + + def _build_messages( + self, + record: EpisodeRecord, + question_type: str, + frame_timestamp: float, + camera_key: str, + ) -> list[dict[str, Any]]: + prompt = load_prompt("vqa").format( + episode_task=record.episode_task, + question_type=question_type, + ) + images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key) + content = [*to_image_blocks(images), {"type": "text", "text": prompt}] + return [{"role": "user", "content": content}] + + def _postprocess(self, result: Any) -> tuple[str, dict[str, Any]] | None: + if not isinstance(result, dict): + return None + question = result.get("question") + answer = result.get("answer") + if not isinstance(question, str) or not question.strip(): + return None + if not isinstance(answer, dict): + return None + # The validator will enforce shape; here we just sanity-check that the + # answer matches *some* known shape so we can drop garbage early. 
+ if classify_vqa_answer(answer) is None: + return None + return question.strip(), answer + + +def _has_image_block(messages: list[dict[str, Any]]) -> bool: + """Return True if any user content block is a populated image block.""" + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for block in content: + if isinstance(block, dict) and block.get("type") == "image": + return True + return False diff --git a/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py new file mode 100644 index 000000000..616f9ce1b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/interjections_and_speech.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms). + +Two sub-passes: + +1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the + canonical task). No interjection row — the canonical task is already the + user utterance from ``meta/tasks.parquet``. + +2. For mid-episode interruptions, emit a co-timestamped pair: + {role:user, style:interjection, content:...} + speech atom (role:assistant, style:None, tool_calls=[say(...)]) + Both rows go in ``language_events`` at the same timestamp. 
+ +The ``plan`` module's :meth:`run_plan_updates` reuses this module's +interjection timestamps to refresh the ``plan`` row at the same instant. +""" + +from __future__ import annotations + +import random +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import InterjectionsConfig +from ..frames import FrameProvider, null_provider, to_image_blocks +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame +from ..staging import EpisodeStaging +from ..vlm_client import VlmClient +from ..writer import speech_atom + + +@dataclass +class InterjectionsAndSpeechModule: + """Generate task-start speech and mid-episode interjection/speech pairs.""" + + vlm: VlmClient + config: InterjectionsConfig + seed: int = 1729 + frame_provider: FrameProvider = field(default_factory=null_provider) + + @property + def enabled(self) -> bool: + return self.config.enabled + + def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None: + rows: list[dict[str, Any]] = [] + if record.frame_timestamps: + t0 = float(record.frame_timestamps[0]) + initial = self._initial_speech(record) + if initial: + rows.append(speech_atom(t0, initial)) + # Pull the ``plan`` module's subtask spans for this episode so the + # interjection prompt can ground itself in the actual current + # subtask at each chosen timestamp. The ``plan`` module ran first. 
+ episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None + subtask_spans = reconstruct_subtask_spans(staging.read("plan"), episode_end_t=episode_end_t) + rows.extend(self._mid_episode_interjections(record, subtask_spans)) + staging.write("interjections", rows) + + @staticmethod + def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None: + current: str | None = None + for span in spans: + if float(span["start"]) <= t: + current = span.get("text") + else: + break + return current + + def _initial_speech(self, record: EpisodeRecord) -> str | None: + prompt = load_prompt("interjections_initial_speech").format( + episode_task=record.episode_task, + ) + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + result = self.vlm.generate_json([messages])[0] + if isinstance(result, dict) and isinstance(result.get("text"), str): + text = result["text"].strip() + if text: + return text + return None + + def _mid_episode_interjections( + self, + record: EpisodeRecord, + subtask_spans: Sequence[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Generate interjections aligned with the actual demo trajectory. + + Teleop data is frozen — the robot already executed every step in + the video. A *counterfactual* interjection like "actually skip + the wipe" contradicts what then happens in the video, which is + what qwen36moe-10/11 surfaced as low-quality interjections. + + Instead, anchor every interjection at a subtask boundary and + write it as a natural user request for the *upcoming* subtask. + The robot's visible next behavior IS the interjection's effect, + so the training signal stays consistent: interjection text → + plan refresh → action stream all line up. + """ + if self.config.max_interjections_per_episode <= 0: + return [] + if len(subtask_spans) < 2: + # Need at least one transition (subtask 0 → subtask 1). + return [] + # Deterministic per-episode RNG so reruns are stable across SLURM jobs. 
+ rng = random.Random(f"{self.seed}:{record.episode_index}:interjection") + + # Boundaries: the start time of every subtask except the first + # (which is just t0 and is covered by the initial-task speech atom). + boundaries: list[tuple[float, str, str]] = [] + for i in range(1, len(subtask_spans)): + ts = float(subtask_spans[i]["start"]) + if ts < self.config.interjection_min_t: + continue + prev_text = (subtask_spans[i - 1].get("text") or "").strip() + next_text = (subtask_spans[i].get("text") or "").strip() + if not next_text: + continue + boundaries.append((ts, prev_text, next_text)) + if not boundaries: + return [] + + n = min(self.config.max_interjections_per_episode, len(boundaries)) + chosen = sorted(rng.sample(boundaries, n), key=lambda b: b[0]) + + out: list[dict[str, Any]] = [] + for t, prev_subtask, next_subtask in chosen: + t_snap = snap_to_frame(t, record.frame_timestamps) + # Window straddles the boundary so the VLM sees the end of the + # previous subtask and the start of the next one — same + # conditioning the policy will see at training time. 
+ window_ts = self._window_timestamps(t_snap, record.frame_timestamps) + prompt = load_prompt("interjections_interjection").format( + episode_task=record.episode_task, + prev_subtask=prev_subtask or "(starting from initial state)", + next_subtask=next_subtask, + timestamp=t_snap, + window_seconds=self.config.interjection_window_seconds, + ) + images = self.frame_provider.frames_at(record, window_ts) + content = [*to_image_blocks(images), {"type": "text", "text": prompt}] + messages = [{"role": "user", "content": content}] + result = self.vlm.generate_json([messages])[0] + if not isinstance(result, dict): + continue + interjection_text = result.get("interjection") + speech_text = result.get("speech") + if not isinstance(interjection_text, str) or not interjection_text.strip(): + continue + if not isinstance(speech_text, str) or not speech_text.strip(): + continue + out.append( + { + "role": "user", + "content": interjection_text.strip(), + "style": "interjection", + "timestamp": t_snap, + "tool_calls": None, + } + ) + out.append(speech_atom(t_snap, speech_text.strip())) + return out + + def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]: + """Return a small set of frame timestamps centered on ``t_anchor``. + + The window straddles the subtask boundary the interjection sits + on: roughly half the frames cover the end of the previous + subtask, half cover the start of the next one. The VLM therefore + sees BOTH what just finished AND what's about to start, which is + the conditioning we need to write a natural "now please do X" + request that matches the visible upcoming behavior. + """ + if not frame_timestamps: + return [t_anchor] + n = max(1, int(self.config.interjection_window_frames)) + if n == 1: + return [t_anchor] + window = float(self.config.interjection_window_seconds) + step = window / max(1, n - 1) + # Center the window on the anchor so half lands before, half after. 
+ start_offset = -window / 2.0 + targets = [t_anchor + start_offset + step * i for i in range(n)] + first_ts = float(frame_timestamps[0]) + last_ts = float(frame_timestamps[-1]) + snapped: list[float] = [] + seen: set[float] = set() + for tgt in targets: + clamped = min(last_ts, max(first_ts, tgt)) + t = snap_to_frame(clamped, frame_timestamps) + if t not in seen: + seen.add(t) + snapped.append(t) + return snapped or [t_anchor] diff --git a/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py new file mode 100644 index 000000000..b6df6551c --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/modules/plan_subtasks_memory.py @@ -0,0 +1,780 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""``plan`` module: subtask decomposition + plan + memory (PERSISTENT styles).""" + +from __future__ import annotations + +import logging +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from ..config import PlanConfig +from ..frames import ( + FrameProvider, + null_provider, + to_contact_sheet_blocks, +) +from ..prompts import load as load_prompt +from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame +from ..staging import EpisodeStaging +from ..vlm_client import VlmClient + +logger = logging.getLogger(__name__) + + +# Prepended to every describe / segment prompt so the VLM knows the images are +# timestamped contact-sheet grids, not a single video, and reads the burned-in +# per-tile timestamp when choosing boundaries. +def _contact_sheet_preamble(columns: int) -> str: + return ( + "CONTACT SHEETS — how to read the images below:\n" + f"- Each image is a grid of sampled video frames, {columns} per row, " + "with time running left-to-right then top-to-bottom (row-major).\n" + "- Each frame has its timestamp burned into the top-left corner, e.g. " + '"012.50s". Use that printed timestamp (not the tile position) when you ' + "choose start/end times; boundaries should land on or near a printed " + "timestamp.\n" + "- Frames continue across grids: an action may span the end of one sheet " + "and the start of the next, so do not place a boundary just because a new " + "image begins.\n\n" + ) + + +# Appended to every describe (and segment) prompt. A visual, causal definition +# of where one event ends and the next begins — adapted from macrodata/refiner — +# to sharpen cut points while the existing prompt keeps owning the imperative +# phrasing. 
_CAUSAL_BOUNDARY_RULES = (
    "EVENT BOUNDARIES — where one event ends and the next begins:\n"
    "- Start a new event whenever the world state changes: an object becomes "
    "held (the gripper closes on it), an object is released (the gripper opens "
    "and it stays put), an object reaches a new location, a lid/door/drawer "
    "changes open/closed state, a tool starts or stops affecting a surface, or "
    "contents visibly move (e.g. poured).\n"
    "- If a single action changes the same state gradually and continuously, "
    "keep it as ONE event — do not split it.\n"
    "- If the same action repeats on different objects or target locations, "
    "treat each repetition as a separate event.\n"
    "- Do NOT create boundaries for idle time, camera motion, hesitation, or "
    "tiny hand adjustments."
)


@dataclass
class PlanSubtasksMemoryModule:
    """Generate subtask spans, plan, and memory rows.

    All output is persistent (lives in ``language_persistent``) and is written
    to the ``"plan"`` staging slot:

    - ``task_aug`` rows (optional): task phrasings stamped at the first frame.
    - ``subtask`` rows: one per span, stamped at the span's *start* timestamp
      (snapped to an exact frame).
    - ``plan`` rows: emitted by :meth:`run_episode` at every subtask boundary
      (including the first one), and additionally at every interjection
      timestamp via :meth:`run_plan_updates` (called by the executor after
      the ``interjections`` module completes).
    - ``memory`` rows: emitted at each subtask boundary (= subtask start
      timestamp from the second subtask onward).
    """

    vlm: VlmClient
    config: PlanConfig
    frame_provider: FrameProvider = field(default_factory=null_provider)

    @property
    def enabled(self) -> bool:
        """Whether this module is switched on in its config."""
        return self.config.enabled

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Build every persistent row for one episode and stage it under ``"plan"``.

        Row order in the staged list is: ``task_aug`` rows, ``subtask`` rows,
        ``plan`` rows (when ``emit_plan``), then ``memory`` rows (when
        ``emit_memory``).
        """
        rows: list[dict[str, Any]] = []
        # Task driving every plan-module prompt: canonical episode_task, or a
        # video-derived one when it's empty/placeholder (see derive_task_*).
        effective_task = self._resolve_effective_task(record)
        # task_aug rows at t=0: phrasings the renderer rotates ${task} through.
        # Either the structured 5-axis taxonomy (task_aug_axes.enabled) or
        # free-form n_task_rephrasings; the effective task is always emitted
        # first so the rotation covers the source-of-truth phrasing.
        t0 = float(record.frame_timestamps[0]) if record.frame_timestamps else 0.0
        variants: list[str] | None = None
        if self.config.task_aug_axes.enabled and effective_task:
            variants = self._generate_task_aug_by_axes(effective_task, self.config.task_aug_axes)
        elif self.config.n_task_rephrasings > 0 and effective_task:
            variants = self._generate_task_rephrasings(effective_task, n=self.config.n_task_rephrasings)
        if variants is not None:
            rows.extend(self._task_aug_rows([effective_task, *variants], t0))

        subtask_spans = self._generate_subtasks(record, task=effective_task)
        # Snap every span start once; the subtask, plan and memory rows all
        # stamp themselves at these same frame timestamps.
        starts = [snap_to_frame(span["start"], record.frame_timestamps) for span in subtask_spans]

        # subtask rows
        for span, start_t in zip(subtask_spans, starts, strict=True):
            rows.append(
                {
                    "role": "assistant",
                    "content": span["text"],
                    "style": "subtask",
                    "timestamp": start_t,
                    "tool_calls": None,
                }
            )
        # Plan rows at every subtask boundary (incl. t=0). The plan is a
        # numbered list of still-todo subtasks, so re-emitting at each
        # boundary makes it shrink as work progresses — ${plan} at frame t is
        # exactly what's left to do.
        if self.config.emit_plan:
            for boundary_t in starts:
                plan_text = self._generate_plan(
                    record, subtask_spans, refresh_t=boundary_t, task=effective_task
                )
                if plan_text is not None:
                    rows.append(
                        {
                            "role": "assistant",
                            "content": plan_text,
                            "style": "plan",
                            "timestamp": float(boundary_t),
                            "tool_calls": None,
                        }
                    )
        # memory rows at every subtask boundary except the very first start;
        # skipped entirely when ``emit_memory`` is False (subtasks-only / plan-only).
        prior_memory = ""
        if self.config.emit_memory:
            for i in range(1, len(subtask_spans)):
                completed = subtask_spans[i - 1]["text"]
                remaining = [s["text"] for s in subtask_spans[i:]]
                mem_text = self._generate_memory(
                    record, prior_memory, completed, remaining, task=effective_task
                )
                if mem_text:
                    rows.append(
                        {
                            "role": "assistant",
                            "content": mem_text,
                            "style": "memory",
                            "timestamp": starts[i],
                            "tool_calls": None,
                        }
                    )
                    # Only a non-empty memory becomes the next call's context.
                    prior_memory = mem_text
        staging.write("plan", rows)

    # ------------------------------------------------------------------
    # Task derivation + rephrasings
    # ------------------------------------------------------------------

    # Deliberately NOT annotated: inside a ``@dataclass`` an annotated class
    # attribute becomes a dataclass field (and an ``__init__`` parameter).
    # This is a shared constant, not per-instance state.
    _PLACEHOLDER_TASKS = frozenset(
        {
            "debug",
            "test",
            "tbd",
            "todo",
            "n/a",
            "na",
            "untitled",
            "unnamed",
            "default",
            "placeholder",
        }
    )

    def _resolve_effective_task(self, record: EpisodeRecord) -> str:
        """Decide which task string drives the ``plan`` module for this episode.

        Returns the stripped ``record.episode_task`` unless
        ``config.derive_task_from_video`` selects video derivation:

        - ``"always"``: use the video-derived task when the VLM returns one.
        - ``"if_short"``: derive only when :meth:`_task_seems_bad` flags the
          canonical task.
        - anything else (including unset): keep the canonical task.

        Falls back to the canonical task whenever derivation yields nothing.
        """
        canonical = (record.episode_task or "").strip()
        mode = (self.config.derive_task_from_video or "off").strip().lower()
        if mode == "always":
            derived = self._derive_task_from_video(record)
            return derived or canonical
        if mode == "if_short" and self._task_seems_bad(canonical):
            derived = self._derive_task_from_video(record)
            if derived:
                return derived
        return canonical

    def _task_seems_bad(self, task: str) -> bool:
        """Return True for an empty, too-short, or placeholder task string.

        "Too short" means fewer whitespace-separated words than
        ``config.derive_task_min_words``.
        """
        if not task:
            return True
        if len(task.split()) < int(self.config.derive_task_min_words):
            return True
        return task.lower() in self._PLACEHOLDER_TASKS

    @staticmethod
    def _task_aug_rows(phrasings: Sequence[str], t0: float) -> list[dict[str, Any]]:
        """Build deduplicated ``task_aug`` rows (role=user) at ``t0``.

        Phrasings are stripped; empty ones and exact repeats are skipped, and
        first-seen order is preserved.
        """
        seen: set[str] = set()
        rows: list[dict[str, Any]] = []
        for phrasing in phrasings:
            key = phrasing.strip()
            if not key or key in seen:
                continue
            seen.add(key)
            rows.append(
                {"role": "user", "content": key, "style": "task_aug", "timestamp": t0, "tool_calls": None}
            )
        return rows

    # ------------------------------------------------------------------
    # VLM call helpers — every plan-module prompt follows the same shape:
    # build messages → single VLM call → pull a named field.
    # ------------------------------------------------------------------

    def _vlm_json(self, messages: list[dict[str, Any]]) -> dict[str, Any] | None:
        """Run a single VLM call and return its JSON object, or ``None``.

        ``None`` covers both an empty batch result and a first result that is
        not a dict, so callers never index into a malformed response.
        """
        results = self.vlm.generate_json([messages])
        result = results[0] if results else None
        return result if isinstance(result, dict) else None

    def _vlm_field(self, messages: list[dict[str, Any]], field: str) -> Any:
        """Run a single VLM call and return ``result[field]`` or ``None``.

        Centralizes the ``vlm.generate_json([m])[0]`` + ``isinstance(dict)``
        dance every prompt-call site needs.
        """
        result = self._vlm_json(messages)
        return None if result is None else result.get(field)

    @staticmethod
    def _text_message(text: str) -> list[dict[str, Any]]:
        """One-shot text-only user message wrapped for ``generate_json``."""
        return [{"role": "user", "content": [{"type": "text", "text": text}]}]

    def _video_message(
        self,
        record: EpisodeRecord,
        prompt: str,
        window: tuple[float, float] | None = None,
    ) -> list[dict[str, Any]]:
        """User message combining the (optionally windowed) contact sheets with ``prompt``.

        The prompt is always prefixed with a short explanation of how to read
        the timestamped grids, so the model treats them as one ordered
        sequence of frames rather than unrelated images.
        """
        prompt = _contact_sheet_preamble(self.config.contact_sheet_columns) + prompt
        content = [*self._episode_video_block(record, window=window), {"type": "text", "text": prompt}]
        return [{"role": "user", "content": content}]

    def _derive_task_from_video(self, record: EpisodeRecord) -> str | None:
        """Ask the VLM "what is this video about" with no task hint at all.

        Returns the stripped ``task`` field, or ``None`` when it is missing,
        not a string, or blank.
        """
        text = self._vlm_field(self._video_message(record, load_prompt("plan_video_task")), "task")
        return text.strip() if isinstance(text, str) and text.strip() else None

    def _generate_task_rephrasings(self, base_task: str, *, n: int) -> list[str]:
        """Generate up to ``n`` text-only paraphrases of ``base_task``.

        Non-string entries are ignored, surrounding whitespace and quotes are
        stripped, and empty results are dropped before truncating to ``n``.
        """
        if n <= 0 or not base_task:
            return []
        prompt = load_prompt("plan_task_rephrasings").format(base_task=base_task, n=n)
        raw = self._vlm_field(self._text_message(prompt), "rephrasings")
        if not isinstance(raw, list):
            return []
        out = [item.strip().strip('"').strip("'") for item in raw if isinstance(item, str)]
        return [s for s in out if s][:n]

    # ------------------------------------------------------------------
    # Structured 5-axis task augmentation (EgoMimic-style taxonomy)
    # ------------------------------------------------------------------

    def _generate_task_aug_by_axes(self, base_task: str, axes_cfg: Any) -> list[str]:
        """One VLM call → variants along the 5-axis taxonomy.

        Variants from all axes are flattened into a single deduplicated list
        (the downstream pipeline doesn't need to know about the per-axis
        bucketing — every variant becomes a ``task_aug`` row). Order
        is preserved for reproducibility: synonym_paraphrase first,
        then omit_arm, then omit_orientation, then omit_grasp_method,
        then combined_omissions.
        """
        if not base_task:
            return []
        prompt = load_prompt("plan_task_aug_axes").format(
            base_task=base_task,
            n_synonym=axes_cfg.synonym_paraphrase,
            n_omit_arm=axes_cfg.omit_arm,
            n_omit_orientation=axes_cfg.omit_orientation,
            n_omit_grasp_method=axes_cfg.omit_grasp_method,
            n_combined=axes_cfg.combined_omissions,
        )
        result = self._vlm_json(self._text_message(prompt))
        if result is None:
            return []
        ordered_axes = (
            "synonym_paraphrase",
            "omit_arm",
            "omit_orientation",
            "omit_grasp_method",
            "combined_omissions",
        )
        flat: list[str] = []
        seen: set[str] = set()
        for axis in ordered_axes:
            entries = result.get(axis)
            if not isinstance(entries, list):
                continue
            for item in entries:
                if not isinstance(item, str):
                    continue
                key = item.strip().strip('"').strip("'")
                if not key or key in seen:
                    continue
                seen.add(key)
                flat.append(key)
        return flat

    def _episode_video_block(
        self, record: EpisodeRecord, window: tuple[float, float] | None = None
    ) -> list[dict[str, Any]]:
        """Timestamped contact sheets for the describe / segmentation prompts.

        Always renders the (optionally windowed) episode as contact sheets:
        frames sampled at ``frames_per_second`` and packed into timestamped
        grids. ``max_frames_per_prompt`` caps the frame count; whole
        episodes that exceed it are windowed upstream in
        :meth:`_generate_subtasks` so each call stays within budget while the
        full episode keeps its sampling density.

        When ``window=(w0, w1)`` is given the badges are WINDOW-RELATIVE
        (``ts - w0``) to match the window-relative time frame the
        segmentation prompt works in (spans are offset back to absolute time
        afterwards).
        """
        if not record.frame_timestamps:
            return []
        if window is not None:
            w0, w1 = float(window[0]), float(window[1])
            dur = max(0.0, w1 - w0)
            n = max(1, int(round(dur * self.config.frames_per_second)) + 1)
            n = min(n, self.config.max_frames_per_prompt)
            if n <= 1 or dur <= 0.0:
                # Degenerate window: a single frame at its midpoint.
                timestamps = [0.5 * (w0 + w1)]
            else:
                step = dur / (n - 1)
                timestamps = [w0 + i * step for i in range(n)]
            frames = self.frame_provider.frames_at(record, timestamps)
            # The slice keeps labels and frames paired if fewer frames come back.
            rel = [ts - w0 for ts in timestamps[: len(frames)]]
            return self._contact_sheet_blocks(frames, rel)
        episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
        n = max(1, int(round(episode_duration * self.config.frames_per_second)) + 1)
        n = min(n, self.config.max_frames_per_prompt)
        timestamps = self._uniform_episode_timestamps(record, n)
        frames = self.frame_provider.frames_at(record, timestamps)
        return self._contact_sheet_blocks(frames, timestamps[: len(frames)])

    @staticmethod
    def _uniform_episode_timestamps(record: EpisodeRecord, n: int) -> list[float]:
        """``n`` timestamps spanning ``[t0, t_last]`` of the episode uniformly.

        When ``n`` is at least the number of recorded frames, every frame
        timestamp is returned as-is. A zero-length episode or ``n <= 1``
        yields ``t0`` repeated ``max(1, n)`` times.
        """
        ts = record.frame_timestamps
        if n >= len(ts):
            return [float(t) for t in ts]
        t0, t_last = float(ts[0]), float(ts[-1])
        if t_last <= t0 or n <= 1:
            return [t0] * max(1, n)
        step = (t_last - t0) / (n - 1)
        return [t0 + i * step for i in range(n)]

    def _contact_sheet_blocks(self, frames: list[Any], timestamps: list[float]) -> list[dict[str, Any]]:
        """Build timestamped contact-sheet image blocks from decoded frames."""
        return to_contact_sheet_blocks(
            frames,
            timestamps,
            columns=self.config.contact_sheet_columns,
            frames_per_sheet=self.config.contact_sheet_frames_per_sheet,
            frame_width=self.config.contact_sheet_frame_width,
            quality=self.config.contact_sheet_quality,
        )

    def run_plan_updates(
        self,
        record: EpisodeRecord,
        staging: EpisodeStaging,
        interjection_times: Sequence[float],
        interjection_texts: Sequence[str] | None = None,
    ) -> None:
        """Append additional ``plan`` rows at every interjection timestamp.

        Reads the staged ``"plan"`` rows, rebuilds the subtask spans from
        them, and adds one ``plan`` row per interjection time (snapped to a
        frame) that does not already carry a plan row. The full row list is
        then written back. No-op when ``config.emit_plan`` is False.

        ``interjection_texts`` must match ``interjection_times`` in length
        (``zip(..., strict=True)`` raises ``ValueError`` otherwise). The text
        is passed on to :meth:`_generate_plan`, which currently ignores it
        because the plan is deterministic.
        """
        if not self.config.emit_plan:
            return
        existing = staging.read("plan")
        # Pass the last frame timestamp so the final span is closed (else its
        # end == start, zero duration, and a refresh inside it is missed).
        episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
        spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
        already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
        new_rows = list(existing)

        texts: list[str | None] = (
            [None] * len(interjection_times)
            if interjection_texts is None
            else [str(t) if t else None for t in interjection_texts]
        )
        for raw_t, inter_text in zip(interjection_times, texts, strict=True):
            t = snap_to_frame(raw_t, record.frame_timestamps)
            if t in already_planned:
                continue
            # Mark before generating so repeated interjection times are
            # attempted only once, even when no plan text comes back.
            already_planned.add(t)
            plan_text = self._generate_plan(record, spans, refresh_t=t, interjection=inter_text)
            if plan_text is not None:
                new_rows.append(
                    {
                        "role": "assistant",
                        "content": plan_text,
                        "style": "plan",
                        "timestamp": t,
                        "tool_calls": None,
                    }
                )
        staging.write("plan", new_rows)

    def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
        """Generate subtask spans, optionally via a multi-call quality chain.

        Single call (default): watch video → emit subtask JSON.

        Multi-call (opt-in via ``config.subtask_describe_first``):
          1. a grounding pass that narrates ONLY what is visible (no JSON
             commitment to subtasks yet); its description is injected into
             the segmentation prompt so the model segments its own grounded
             observations instead of pattern-matching the task text.
          2. segmentation — emit subtask JSON (as before).

        Episodes whose frame count at ``frames_per_second`` exceeds
        ``max_frames_per_prompt`` are delegated to
        :meth:`_generate_subtasks_windowed`.
        """
        if record.row_count == 0 or not record.frame_timestamps:
            return []
        episode_duration = record.frame_timestamps[-1] - record.frame_timestamps[0]
        effective_task = task if task is not None else record.episode_task

        # ---- Auto-windowing (keeps the full sampling density) --------
        # Contact sheets are cheap, but a whole long episode sampled at
        # ``frames_per_second`` can still exceed ``max_frames_per_prompt``.
        # When it does, split into consecutive windows sized for that frame
        # budget (one describe→segment call each), then merge + stitch — so
        # an episode of any length is covered rather than subsampled into one
        # sparse call.
        fps = max(1e-6, float(self.config.frames_per_second))
        # Clamp the budget to at least one frame: a non-positive value would
        # give a zero-length window and the window loop would never advance.
        max_frames = max(1, int(self.config.max_frames_per_prompt))
        n_whole = int(round(episode_duration * fps)) + 1
        if n_whole > max_frames:
            window_s = max_frames / fps
            return self._generate_subtasks_windowed(record, effective_task, window_s)

        # ---- Pass 1 (optional): grounding description ----------------
        observation_block = ""
        if getattr(self.config, "subtask_describe_first", False):
            description = self._describe_episode(record, effective_task)
            if description:
                observation_block = (
                    "You watched this video and described, chronologically, "
                    "ONLY what the robot actually does:\n"
                    f'"""{description}"""\n\n'
                    "Segment THAT grounded description (cross-checked against "
                    "the video) into atomic subtasks. Do not introduce any "
                    "action that is not in your description above.\n\n"
                )

        # ---- Pass 2: segmentation ------------------------------------
        prompt = self._with_causal_rules(
            load_prompt("plan_subtasks").format(
                episode_task=effective_task,
                min_subtask_seconds=self.config.min_subtask_seconds,
                max_steps=self.config.plan_max_steps,
                episode_duration=f"{episode_duration:.3f}",
                observation_block=observation_block,
            )
        )
        spans = self._vlm_field(self._video_message(record, prompt), "subtasks")
        cleaned = self._clean_spans(spans, record)
        if not cleaned:
            return []

        # ---- Full-episode coverage stitch ----------------------------
        # The VLM can start after t0 or leave gaps, so frames fall through
        # with no active subtask. Always stitch into a contiguous
        # [t0, t_last] cover.
        return self._stitch_full_coverage(cleaned, record)

    def _generate_subtasks_windowed(
        self, record: EpisodeRecord, task: str, window_s: float
    ) -> list[dict[str, Any]]:
        """Subtask generation in fixed-length windows at constant fps.

        Splits ``[t0, t_last]`` into consecutive windows of ``window_s``
        seconds, runs the describe -> segment chain on each window's own
        frames (sampled at ``frames_per_second``), offsets
        each window's spans back to absolute episode time, then merges +
        stitches into a contiguous whole-episode cover.

        Raises:
            ValueError: if ``window_s`` is not a positive number (the window
                loop could otherwise never advance).
        """
        if not window_s > 0:
            raise ValueError(f"window_s must be positive, got {window_s!r}")
        t0 = float(record.frame_timestamps[0])
        t_last = float(record.frame_timestamps[-1])
        all_spans: list[dict[str, Any]] = []
        w0 = t0
        n_windows = 0
        while w0 < t_last - 1e-6:
            w1 = min(w0 + window_s, t_last)
            if w1 <= w0:
                # No forward progress (window below float resolution); stop
                # rather than spin forever.
                break
            all_spans.extend(self._subtasks_for_window(record, task, w0, w1))
            n_windows += 1
            w0 = w1
        logger.info(
            "episode %d: windowed subtask gen over %d window(s) of %.1fs -> %d raw spans",
            record.episode_index,
            n_windows,
            window_s,
            len(all_spans),
        )
        # Merge across windows: clamp to the absolute episode, sort, and
        # frame-snap to distinct starts (handles any boundary collisions).
        cleaned = self._clean_spans(all_spans, record)
        if not cleaned:
            return []
        return self._stitch_full_coverage(cleaned, record)

    def _subtasks_for_window(
        self, record: EpisodeRecord, task: str, w0: float, w1: float
    ) -> list[dict[str, Any]]:
        """Run describe -> segment on one ``[w0, w1]`` window.

        The model works in window-RELATIVE time ``[0, L]`` (it perceives
        the window as a clip starting at 0); spans are offset back to
        absolute ``[w0, w1]`` before returning.
        """
        window = (w0, w1)
        win_len = max(0.0, w1 - w0)

        observation_block = ""
        if getattr(self.config, "subtask_describe_first", False):
            description = self._describe_episode(record, task, window=window)
            if description:
                observation_block = (
                    "You watched this video clip and described, chronologically, "
                    "ONLY what the robot actually does:\n"
                    f'"""{description}"""\n\n'
                    "Segment THAT grounded description (cross-checked against "
                    "the clip) into atomic subtasks. Do not introduce any "
                    "action that is not in your description above.\n\n"
                )

        prompt = self._with_causal_rules(
            load_prompt("plan_subtasks").format(
                episode_task=task,
                min_subtask_seconds=self.config.min_subtask_seconds,
                max_steps=self.config.plan_max_steps,
                episode_duration=f"{win_len:.3f}",
                observation_block=observation_block,
            )
        )
        spans = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
        # Window-relative clamp; no frame-snap dedupe yet (done on the
        # merged absolute set).
        cleaned = self._clean_spans(spans, record, bounds=(0.0, win_len), dedupe=False)
        if not cleaned:
            return []

        # Offset window-relative spans back to absolute episode time.
        for s in cleaned:
            s["start"] = w0 + float(s["start"])
            s["end"] = w0 + float(s["end"])
        return cleaned

    def _stitch_full_coverage(
        self, spans: list[dict[str, Any]], record: EpisodeRecord
    ) -> list[dict[str, Any]]:
        """Make subtask spans tile the full episode with no gaps.

        * The first subtask starts at the episode's first frame ``t0``
          (any idle / approach before the first labelled action is folded
          into it), so every early frame has an active subtask.
        * Each subtask's ``end`` is snapped to the next subtask's
          ``start`` (gaps between spans are closed), and the final
          subtask's ``end`` extends to the last frame ``t_last``.

        Only the FIRST start is moved; all other starts are kept as given.
        Works on shallow copies, so the caller's span dicts are not modified.
        Purely deterministic; runs after the VLM passes.
        """
        if not spans or not record.frame_timestamps:
            return spans
        t0 = float(record.frame_timestamps[0])
        t_last = float(record.frame_timestamps[-1])
        stitched = sorted(({**s} for s in spans), key=lambda s: float(s["start"]))
        stitched[0]["start"] = t0
        for i in range(len(stitched) - 1):
            stitched[i]["end"] = float(stitched[i + 1]["start"])
        stitched[-1]["end"] = t_last
        for s in stitched:
            # Never leave a span whose end precedes its start.
            if float(s["end"]) < float(s["start"]):
                s["end"] = float(s["start"])
        return stitched

    @staticmethod
    def _with_causal_rules(prompt: str) -> str:
        """Append the causal event-boundary rules to a describe/segment prompt."""
        return f"{prompt}\n\n{_CAUSAL_BOUNDARY_RULES}"

    def _clean_spans(
        self,
        spans: Any,
        record: EpisodeRecord,
        bounds: tuple[float, float] | None = None,
        dedupe: bool = True,
    ) -> list[dict[str, Any]]:
        """Clamp / sort / (optionally) dedupe raw VLM subtask spans into valid rows.

        ``spans`` is untrusted model output: anything that is not a non-empty
        list/tuple yields ``[]``. Individual entries are skipped when they
        lack ``start``/``end``/``text``, when a time is non-numeric or NaN, or
        when the text is ``None`` or blank.

        ``bounds`` overrides the clamp range — pass the window's
        ``(w_lo, w_hi)`` when cleaning window-relative spans, or leave
        ``None`` to clamp to the whole episode ``[t0, t_last]``.
        ``dedupe`` runs the frame-snap distinct-start step; skip it for
        window-relative spans (frame snapping is done once on the merged,
        absolute-time set).
        """
        if not spans or not isinstance(spans, (list, tuple)):
            return []
        if bounds is not None:
            lo, hi = float(bounds[0]), float(bounds[1])
        else:
            if not record.frame_timestamps:
                return []
            lo = record.frame_timestamps[0]
            hi = record.frame_timestamps[-1]
        cleaned: list[dict[str, Any]] = []
        for span in spans:
            try:
                start = float(span["start"])
                end = float(span["end"])
                raw_text = span["text"]
            except (KeyError, ValueError, TypeError):
                continue
            # NaN compares unequal to itself; such a span has no usable time.
            if start != start or end != end:
                continue
            # A JSON ``null`` text must not become the literal subtask "None".
            text = "" if raw_text is None else str(raw_text).strip()
            if not text:
                continue
            start = max(lo, min(start, hi))
            end = max(lo, min(end, hi))
            if end < start:
                start, end = end, start
            cleaned.append({"text": text, "start": start, "end": end})
        cleaned.sort(key=lambda s: s["start"])
        if dedupe:
            return self._dedupe_starts_to_distinct_frames(cleaned, record)
        return cleaned

    def _describe_episode(
        self, record: EpisodeRecord, task: str, window: tuple[float, float] | None = None
    ) -> str:
        """Grounding pass: free-form chronological description of the (windowed) video.

        Returns the stripped ``description`` field, or ``""`` when it is
        missing, not a string, or blank.
        """
        prompt = self._with_causal_rules(load_prompt("plan_subtask_describe").format(episode_task=task))
        text = self._vlm_field(self._video_message(record, prompt, window=window), "description")
        return text.strip() if isinstance(text, str) and text.strip() else ""

    @staticmethod
    def _dedupe_starts_to_distinct_frames(
        spans: list[dict[str, Any]], record: EpisodeRecord
    ) -> list[dict[str, Any]]:
        """Bump same-frame subtask starts onto distinct frames.

        Two consecutive VLM spans whose ``start`` rounds to the same
        source frame (after :func:`snap_to_frame`) would otherwise emit
        two ``style=subtask`` rows at the identical persistent
        timestamp. The training-time renderer's ``active_at(t,
        style=subtask)`` resolver can't disambiguate that and raises
        ``Ambiguous resolver for style='subtask'``.

        Walk the (sorted-by-start) spans, snap each to its frame, and
        if the snapped frame is already taken push the span onto the
        next unused frame so both subtasks survive on distinct
        timestamps. If the episode ends before a free frame is found,
        the trailing span is dropped with a warning — better than
        poisoning the render.
        """
        if not spans:
            return spans
        frames = record.frame_timestamps
        if not frames:
            return spans
        used: set[float] = set()
        out: list[dict[str, Any]] = []
        for span in spans:
            ts = snap_to_frame(span["start"], frames)
            if ts in used:
                next_ts = next((f for f in frames if f > ts and f not in used), None)
                if next_ts is None:
                    logger.warning(
                        "episode %d: subtask %r snapped to occupied frame "
                        "%.3f and no free later frame exists — dropping",
                        record.episode_index,
                        span.get("text"),
                        ts,
                    )
                    continue
                ts = next_ts
            used.add(ts)
            new_span = {**span, "start": ts}
            # Moving the start forward must not leave end < start.
            if float(new_span.get("end", ts)) < ts:
                new_span["end"] = ts
            out.append(new_span)
        return out

    def _generate_plan(
        self,
        record: EpisodeRecord,  # noqa: ARG002 (kept for signature stability)
        subtask_spans: Sequence[dict[str, Any]],
        *,
        refresh_t: float | None = None,
        interjection: str | None = None,  # noqa: ARG002
        task: str | None = None,  # noqa: ARG002
    ) -> str | None:
        """Deterministic plan = numbered list of *still-todo* subtasks.

        No VLM call: the plan is one ``"<i>. <subtask text>"`` line per
        remaining subtask, numbered from 1 and joined with newlines.

        With ``refresh_t`` set (from ``run_plan_updates`` on interjections,
        and ``run_episode`` at each boundary), only subtasks starting at or
        after ``refresh_t`` are included — so it always describes what's
        left. With ``refresh_t=None`` every subtask is listed.

        Returns ``None`` when there are no subtasks or none remain.
        """
        if not subtask_spans:
            return None
        remaining = [
            s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
        ]
        if not remaining:
            # Past the last subtask boundary on a late refresh — nothing
            # left to plan; emit None so the caller skips the row.
            return None
        # ``or ""`` + ``str`` tolerate a missing / None / non-string text.
        return "\n".join(
            f"{i}. {str(span.get('text') or '').strip()}" for i, span in enumerate(remaining, start=1)
        )

    def _generate_memory(
        self,
        record: EpisodeRecord,
        prior_memory: str,
        completed: str,
        remaining: Sequence[str],
        *,
        task: str | None = None,
    ) -> str:
        """Ask the VLM (text-only) for the updated memory after one subtask.

        Args:
            record: Episode whose ``episode_task`` is used when ``task`` is None.
            prior_memory: Previous memory text; ``"(none)"`` is sent when empty.
            completed: Text of the subtask that just finished.
            remaining: Texts of the subtasks still to do; ``"(none)"`` is sent
                when empty.
            task: Optional override for the task string in the prompt.

        Returns:
            The stripped ``memory`` field, or ``""`` when it is not a string.
        """
        prompt = load_prompt("plan_memory").format(
            episode_task=(task if task is not None else record.episode_task),
            prior_memory=prior_memory or "(none)",
            completed_subtask=completed,
            remaining_subtasks=", ".join(remaining) if remaining else "(none)",
        )
        memory = self._vlm_field(self._text_message(prompt), "memory")
        return memory.strip() if isinstance(memory, str) else ""
_DIR = Path(__file__).parent


def load(name: str) -> str:
    """Return the UTF-8 text of the prompt template ``<name>.txt``.

    The file is looked up in the directory containing this module. A missing
    template surfaces as the ``FileNotFoundError`` raised by ``read_text``.
    """
    return (_DIR / f"{name}.txt").read_text(encoding="utf-8")
+ +The images above show roughly {window_seconds:.1f} seconds straddling a +subtask boundary in the demonstration: + +- Subtask the robot just finished: "{prev_subtask}" +- Subtask the robot is about to start: "{next_subtask}" +- Time into episode: {timestamp:.2f}s + +Write ONE compact interjection the user would naturally say at this +moment to prompt / confirm / encourage the robot to do "{next_subtask}". +Keep it like a mid-task coaching cue, not a full instruction paragraph. +Also write the robot's compact verbal acknowledgement. + +Hard rules: + +- The interjection MUST be consistent with the next subtask. The user + cannot ask for something different from what the robot then does in + the video. If you're tempted to say "actually skip X" or "do Y + instead", DO NOT — those would contradict the demonstration. +- The interjection must reference an object, location, or action that + is plausible given the visible scene and the next subtask text. +- One short phrase or sentence each. Conversational, not robotic. +- Prefer direct cues: "{next_subtask}, please."; "Now {next_subtask}." +- Keep robot speech very short: "OK.", "On it.", "Doing that." + +Style examples (vary the phrasing — don't reuse these verbatim): + - "Now go ahead and {next_subtask}." + - "Great, can you {next_subtask} next?" + - "{next_subtask}, please." + - "Before you continue, please {next_subtask}." + - "Looking good — {next_subtask} now." + - "Okay, {next_subtask}." + +Output strictly valid JSON: + {{ + "interjection": "", + "speech": "" + }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt new file mode 100644 index 000000000..b5278368b --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_memory.txt @@ -0,0 +1,36 @@ +You are updating the robot's compressed semantic memory at the boundary of +a completed subtask. 
+ +Reference (verbatim from MEM, Torne 2026): +"Remove or compress information in the language memory whenever +appropriate. Keep ONLY the minimal set of relevant information for future +task execution. Specific object attributes (colors, precise quantities of +each item) get discarded when their details won't affect subsequent +actions. Functional outcomes (where items went, how many) are preserved." + +Episode task: "{episode_task}" +Previous memory: {prior_memory} +Just-completed subtask: "{completed_subtask}" +Remaining subtasks (for relevance judgement only): {remaining_subtasks} + +Write the memory as a short FIRST-PERSON, PAST-TENSE narrative of what the +robot has accomplished so far — the running story it would tell itself. + +Authoring rules: +- First person, past tense. Every sentence starts with "I": "I picked + up...", "I opened...", "I moved to...". +- One or two short sentences. Extend the previous memory with the + just-completed subtask; do not rewrite it from scratch. +- Keep WHAT happened (functional outcomes — where items went, how many), + drop HOW (grasp details, motions). +- Compress completed steps and drop object attributes (colors, exact + counts) once they no longer affect the remaining subtasks. + +Example (MEM, Torne 2026): + Before: "I prepared the pot and got the potatoes, milk, and butter. I + moved to the drawer." + After: "I prepared the pot and got the ingredients. I opened the + drawer with the masher." + +Output strictly valid JSON: + {{ "memory": "" }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt new file mode 100644 index 000000000..6b709e41d --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtask_describe.txt @@ -0,0 +1,27 @@ +You are watching a teleoperated robot demonstration from a single +camera. The user asked the robot to: "{episode_task}" + +This is an OBSERVATION pass. 
Watch the entire clip and describe, in +chronological order, ONLY what the robot physically does — the concrete +motions, approaches, contacts, grasps, releases, and relocations you can +actually SEE in the frames. + +Hard rules: +- Describe only motion visible in the video. Do NOT use the task + instruction to guess steps that aren't shown. The instruction is the + goal; the video is ground truth. +- Do NOT segment into named subtasks yet and do NOT output JSON beyond + the single field below. Just narrate what happens. +- Give an approximate timestamp (in seconds) for each distinct event, + e.g. "0.0-1.4s: the base drives forward toward the stove". +- Do NOT invent objects, grasps, destinations, or steps. If the robot + only does one thing (e.g. it just navigates and the clip ends), say + exactly that and nothing more. +- Be concrete and literal. "the gripper closes on the mug" — not "the + robot prepares to make coffee". + +Output strictly valid JSON: + + {{ + "description": "" + }} diff --git a/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt new file mode 100644 index 000000000..e6a5260a7 --- /dev/null +++ b/src/lerobot/annotations/steerable_pipeline/prompts/plan_subtasks.txt @@ -0,0 +1,112 @@ +You are labeling a teleoperated robot demonstration. + +The user originally asked: "{episode_task}" + +You are shown the entire demonstration as a single video. Watch the +whole clip, then segment it into a list of consecutive atomic subtasks +the robot performs. + +{observation_block}GROUNDING — read this first, it overrides everything below: +- Label ONLY what the robot actually does in the video. Every subtask + you emit must correspond to motion you can SEE in specific frames. +- Do NOT invent, anticipate, or pad. If the robot only does one thing + (e.g. it just navigates to a location and the clip ends), emit + EXACTLY ONE subtask. 
Many demonstrations are a single atomic skill. +- ``max_steps`` below is a hard CEILING, not a target. Emitting fewer + subtasks than the ceiling is not just allowed, it is expected for + short / atomic demonstrations. One correct subtask is far better + than several invented ones. +- If the video does not clearly show the action implied by the task, + describe what you actually see — do NOT fabricate the task's steps + from the instruction text. The instruction tells you the goal; the + VIDEO is the ground truth for what happened. + +Authoring rules — Hi Robot atom granularity, pi0.7-style short prompts: + +- Each subtask = one COMPOSITE atomic skill the low-level policy can + execute end-to-end. A "skill" bundles its own approach motion with + its terminal action — do NOT split the approach off as its own + subtask. The whole-arm policy already learns to reach as part of + every manipulation primitive. +- Write each subtask as an IMPERATIVE COMMAND, starting with one of + these verbs (extend only when none fits): + pick up — approach + grasp + lift in one subtask + put on/in — transport + release in one subtask + place on/in — synonym of "put"; pick one and stay consistent + push — contact + linear shove + pull — contact + linear retract + turn — rotary actuation + press