mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
chore(docs): add more guidance to bring your own policies tutorial (#3230)
* chore(docs): add more guidance to bring your own policies tutorial * removing normalization to avoid confusion with processors * trailing whitespace * Update docs/source/bring_your_own_policies.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> * Update docs/source/bring_your_own_policies.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net> * adding get optim params and predict_action chunk * removing extra quotes --------- Signed-off-by: Maxime Ellerbach <maxime@ellerbach.net>
This commit is contained in:
@@ -41,13 +41,15 @@ requires = # your-build-system
|
|||||||
|
|
||||||
## Step 2: Define the Policy Configuration
|
## Step 2: Define the Policy Configuration
|
||||||
|
|
||||||
Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:
|
Create a configuration class that inherits from [`PreTrainedConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/configs/policies.py) and registers your policy type:
|
||||||
|
Here is a template to get you started; customize the parameters and methods as needed for your policy's architecture and training requirements.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# configuration_my_custom_policy.py
|
# configuration_my_custom_policy.py
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import NormalizationMode
|
from lerobot.optim.optimizers import AdamWConfig
|
||||||
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("my_custom_policy")
|
@PreTrainedConfig.register_subclass("my_custom_policy")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -61,22 +63,56 @@ class MyCustomPolicyConfig(PreTrainedConfig):
|
|||||||
hidden_dim: Hidden dimension for the policy network
|
hidden_dim: Hidden dimension for the policy network
|
||||||
# Add your policy-specific parameters here
|
# Add your policy-specific parameters here
|
||||||
"""
|
"""
|
||||||
# ...PreTrainedConfig fields...
|
|
||||||
pass
|
horizon: int = 50
|
||||||
|
n_action_steps: int = 50
|
||||||
|
hidden_dim: int = 256
|
||||||
|
|
||||||
|
optimizer_lr: float = 1e-4
|
||||||
|
optimizer_weight_decay: float = 1e-4
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
# Add any validation logic here
|
if self.n_action_steps > self.horizon:
|
||||||
|
raise ValueError("n_action_steps cannot exceed horizon")
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
def validate_features(self) -> None:
|
||||||
"""Validate input/output feature compatibility."""
|
"""Validate input/output feature compatibility."""
|
||||||
# Implement validation logic for your policy's requirements
|
if not self.image_features:
|
||||||
pass
|
raise ValueError("MyCustomPolicy requires at least one image feature.")
|
||||||
|
if self.action_feature is None:
|
||||||
|
raise ValueError("MyCustomPolicy requires 'action' in output_features.")
|
||||||
|
|
||||||
|
def get_optimizer_preset(self) -> AdamWConfig:
|
||||||
|
return AdamWConfig(lr=self.optimizer_lr, weight_decay=self.optimizer_weight_decay)
|
||||||
|
|
||||||
|
def get_scheduler_preset(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def observation_delta_indices(self) -> list[int] | None:
|
||||||
|
"""Relative timestep offsets the dataset loader provides per observation.
|
||||||
|
|
||||||
|
Return `None` for single-frame policies. For temporal policies that consume
|
||||||
|
multiple past or future frames, return a list of offsets, e.g. `[-20, -10, 0, 10]` for
|
||||||
|
2 past frames, the current frame, and 1 future frame, all at stride 10.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_delta_indices(self) -> list[int]:
|
||||||
|
"""Relative timestep offsets for the action chunk the dataset loader returns.
|
||||||
|
"""
|
||||||
|
return list(range(self.horizon))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reward_delta_indices(self) -> None:
|
||||||
|
return None
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 3: Implement the Policy Class
|
## Step 3: Implement the Policy Class
|
||||||
|
|
||||||
Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:
|
Create your policy implementation by inheriting from [`PreTrainedPolicy`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pretrained.py):
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# modeling_my_custom_policy.py
|
# modeling_my_custom_policy.py
|
||||||
@@ -85,38 +121,74 @@ import torch.nn as nn
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
from .configuration_my_custom_policy import MyCustomPolicyConfig
|
||||||
|
|
||||||
class MyCustomPolicy(PreTrainedPolicy):
|
class MyCustomPolicy(PreTrainedPolicy):
|
||||||
config_class = MyCustomPolicyConfig
|
config_class = MyCustomPolicyConfig # the config class decorated with @register_subclass; `name` below must match its registered string
|
||||||
name = "my_custom_policy"
|
name = "my_custom_policy"
|
||||||
|
|
||||||
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
def __init__(self, config: MyCustomPolicyConfig, dataset_stats: dict[str, Any] = None):
|
||||||
super().__init__(config, dataset_stats)
|
super().__init__(config, dataset_stats)
|
||||||
|
config.validate_features() # not called automatically by the base class
|
||||||
|
self.config = config
|
||||||
|
self.model = ... # your nn.Module here
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset episode state."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def get_optim_params(self) -> dict:
|
||||||
|
"""Return parameters to pass to the optimizer (e.g. with per-group lr/wd)."""
|
||||||
|
return {"params": self.parameters()}
|
||||||
|
|
||||||
|
def predict_action_chunk(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||||
|
"""Return the full action chunk (B, horizon, action_dim) for the current observation."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def select_action(self, batch: dict[str, torch.Tensor], **kwargs) -> torch.Tensor:
|
||||||
|
"""Return a single action for the current timestep (called at inference)."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
|
"""Compute the training loss.
|
||||||
|
|
||||||
|
`batch["action_is_pad"]` is a bool mask of shape (B, horizon) that marks
|
||||||
|
timesteps padded because the episode ended before `horizon` steps; you
|
||||||
|
can exclude those from your loss.
|
||||||
|
"""
|
||||||
|
actions = batch[ACTION]
|
||||||
|
action_is_pad = batch.get("action_is_pad")
|
||||||
|
...
|
||||||
|
return {"loss": ...}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Step 4: Add Data Processors
|
## Step 4: Add Data Processors
|
||||||
|
|
||||||
Create processor functions:
|
Create processor functions. For a concrete reference, see [processor_act.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/processor_act.py) or [processor_diffusion.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/diffusion/processor_diffusion.py).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# processor_my_custom_policy.py
|
# processor_my_custom_policy.py
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||||
|
|
||||||
|
|
||||||
def make_my_custom_policy_pre_post_processors(
|
def make_my_custom_policy_pre_post_processors(
|
||||||
config,
|
config,
|
||||||
|
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
]:
|
]:
|
||||||
"""Create preprocessing and postprocessing functions for your policy."""
|
preprocessor = ... # build your PolicyProcessorPipeline for inputs
|
||||||
pass # Define your preprocessing and postprocessing logic here
|
postprocessor = ... # build your PolicyProcessorPipeline for outputs
|
||||||
|
return preprocessor, postprocessor
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Important - function naming:** LeRobot discovers your processor by name. The function **must** be called `make_{policy_name}_pre_post_processors` (matching the string you passed to `@PreTrainedConfig.register_subclass`).
|
||||||
|
|
||||||
## Step 5: Package Initialization
|
## Step 5: Package Initialization
|
||||||
|
|
||||||
Expose your classes in the package's `__init__.py`:
|
Expose your classes in the package's `__init__.py`:
|
||||||
|
|||||||
Reference in New Issue
Block a user