diff --git a/docs/source/bring_your_own_policies.mdx b/docs/source/bring_your_own_policies.mdx index 9266c9e5b..38c32aa71 100644 --- a/docs/source/bring_your_own_policies.mdx +++ b/docs/source/bring_your_own_policies.mdx @@ -41,13 +41,15 @@ requires = # your-build-system ## Step 2: Define the Policy Configuration -Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type: +Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type: +Here is a template to get you started; customize the parameters and methods as needed for your policy's architecture and training requirements. ```python # configuration_my_custom_policy.py from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import NormalizationMode +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig @PreTrainedConfig.register_subclass("my_custom_policy") @dataclass @@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig): hidden_dim: Hidden dimension for the policy network # Add your policy-specific parameters here """ - # ...PreTrainedConfig fields... 
- pass + + horizon: int = 50 + n_action_steps: int = 50 + hidden_dim: int = 256 + + optimizer_lr: float = 1e-4 + optimizer_weight_decay: float = 1e-4 def __post_init__(self): super().__post_init__() - # Add any validation logic here + if self.n_action_steps > self.horizon: + raise ValueError("n_action_steps cannot exceed horizon") def validate_features(self) -> None: """Validate input/output feature compatibility.""" - # Implement validation logic for your policy's requirements - pass + if not self.image_features: + raise ValueError("MyCustomPolicy requires at least one image feature.") + if self.action_feature is None: + raise ValueError("MyCustomPolicy requires 'action' in output_features.") + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay) + + def get_scheduler_preset(self): + return None # or e.g. CosineDecayWithWarmupSchedulerConfig(...) + + @property + def observation_delta_indices(self) -> list[int] | None: + """Relative timestep offsets the dataset loader provides per observation. + + Return `None` for single-frame policies. For temporal policies that consume + multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for + the current frame plus 2 past frames at stride 10 and 1 future frame at stride 10. + """ + return None + + @property + def action_delta_indices(self) -> list[int]: + """Relative timestep offsets for the action chunk the dataset loader returns. 
+ """ + return list(range(self.horizon)) + + @property + def reward_delta_indices(self) -> None: + return None ``` ## Step 3: Implement the Policy Class -Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class: +Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py): ```python # modeling_my_custom_policy.py @@ -85,38 +121,74 @@ import torch.nn as nn from typing import Any from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION from .configuration_my_custom_policy import MyCustomPolicyConfig class MyCustomPolicy(PreTrainedPolicy): - config_class = MyCustomPolicyConfig + config_class = MyCustomPolicyConfig # your config class from Step 2; `name` below must match the string in @register_subclass name = "my_custom_policy" def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None): super().__init__(config, dataset_stats) + config.validate_features() # not called automatically by the base class + self.config = config + self.model = ... # your nn.Module here + + def reset(self): + """Reset episode state.""" ... + + def get_optim_params(self) -> dict: + """Return parameters to pass to the optimizer (e.g. with per-group lr/wd).""" + return {"params": self.parameters()} + + def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor: + """Return the full action chunk (B, chunk_size, action_dim) for the current observation.""" + ... + + def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor: + """Return a single action for the current timestep (called at inference).""" + ... + + def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """Compute the training loss. 
+ + `batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks + timesteps padded because the episode ended before `horizon` steps; you + can exclude those from your loss. + """ + actions = batch[ACTION] + action_is_pad = batch.get("action_is_pad") + ... + return {"loss": ...} ``` ## Step 4: Add Data Processors -Create processor functions: +Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py). ```python # processor_my_custom_policy.py from typing import Any import torch +from lerobot.processor import PolicyAction, PolicyProcessorPipeline + def make_my_custom_policy_pre_post_processors( config, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, ) -> tuple[ PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], PolicyProcessorPipeline[PolicyAction, PolicyAction], ]: - """Create preprocessing and postprocessing functions for your policy.""" - pass # Define your preprocessing and postprocessing logic here - + preprocessor = ... # build your PolicyProcessorPipeline for inputs + postprocessor = ... # build your PolicyProcessorPipeline for outputs + return preprocessor, postprocessor ``` +**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`). ## Step 5: Package Initialization Expose your classes in the package's `__init__.py`: