Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-14 16:19:45 +00:00)
Compare commits (3 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | ca0087d6da |  |
|  | e3ce2eb743 |  |
|  | 17f4bc4c56 |  |
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
            current step and additional steps going back).
-       horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+       chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
            See `DiffusionPolicy.select_action` for more details.
        input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):

    # Inputs / output structure.
    n_obs_steps: int = 2
-   horizon: int = 16
+   chunk_size: int = 16
    n_action_steps: int = 8

    normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):

    # The original implementation doesn't sample frames for the last 7 steps,
    # which avoids excessive padding and leads to improved training results.
-   drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1
+   drop_n_last_frames: int = 7  # chunk_size - n_action_steps - n_obs_steps + 1

    # Architecture / modeling.
    # Vision backbone.
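The renamed comment's formula checks out against the defaults in these hunks; a quick sanity check:

```python
# Sanity check: the default drop_n_last_frames matches the formula in the comment.
chunk_size, n_action_steps, n_obs_steps = 16, 8, 2
assert chunk_size - n_action_steps - n_obs_steps + 1 == 7  # 16 - 8 - 2 + 1 = 7
```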
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
                f"Got {self.noise_scheduler_type}."
            )

-       # Check that the horizon size and U-Net downsampling is compatible.
+       # Check that the chunk size and U-Net downsampling is compatible.
        # U-Net downsamples by 2 with each stage.
        downsampling_factor = 2 ** len(self.down_dims)
-       if self.horizon % downsampling_factor != 0:
+       if self.chunk_size % downsampling_factor != 0:
            raise ValueError(
-               "The horizon should be an integer multiple of the downsampling factor (which is determined "
-               f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
+               "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
+               f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
            )

    def get_optimizer_preset(self) -> AdamConfig:
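To see why the check matters: each U-Net stage halves the temporal dimension, so `chunk_size` must be divisible by `2 ** len(down_dims)`. A small illustration, assuming the default `down_dims` of `(512, 1024, 2048)`:

```python
# Three U-Net stages halve the temporal dimension three times.
down_dims = (512, 1024, 2048)
downsampling_factor = 2 ** len(down_dims)  # 8
assert 16 % downsampling_factor == 0  # the default chunk_size=16 passes
assert 12 % downsampling_factor != 0  # chunk_size=12 would raise ValueError
```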
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):

    @property
    def action_delta_indices(self) -> list:
-       return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+       return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
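With the defaults (`n_obs_steps=2`, `chunk_size=16`), the renamed property yields indices from one step in the past through 14 steps into the future:

```python
# action_delta_indices with the default config values.
n_obs_steps, chunk_size = 2, 16
indices = list(range(1 - n_obs_steps, 1 - n_obs_steps + chunk_size))  # range(-1, 15)
assert indices[0] == -1 and indices[-1] == 14 and len(indices) == 16
```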
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
        return actions

    @torch.no_grad()
-   def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+   def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
        """Select a single action given environment observations.

        This method handles caching a history of observations and an action trajectory generated by the
        underlying diffusion model. Here's how it works:
          - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
            copied `n_obs_steps` times to fill the cache).
-         - The diffusion model generates `horizon` steps worth of actions.
+         - The diffusion model generates `chunk_size` steps worth of actions.
          - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
        Schematically this looks like:
            ----------------------------------------------------------------------------------------------
-           (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+           (legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
            |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
            |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
            |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
            |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
            ----------------------------------------------------------------------------------------------
-       Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
-       "horizon" may not the best name to describe what the variable actually means, because this period is
+       Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
+       this period is
        actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
        """
        # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
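A minimal sketch of the queueing scheme the docstring describes (hypothetical helper names; the real logic lives inside `select_action`):

```python
from collections import deque

import torch

n_action_steps = 8  # actions kept for execution per model call

# The policy keeps a queue; when it runs dry, one diffusion call generates a
# full chunk and the first n_action_steps actions are queued for execution.
action_queue: deque[torch.Tensor] = deque(maxlen=n_action_steps)

def select_action_sketch(generate_chunk) -> torch.Tensor:
    if len(action_queue) == 0:
        chunk = generate_chunk()  # shape: (chunk_size, action_dim)
        action_queue.extend(chunk[:n_action_steps])
    return action_queue.popleft()

# Example: a stand-in generator producing a (16, 3) chunk.
action = select_action_sketch(lambda: torch.zeros(16, 3))
```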
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
            noise
            if noise is not None
            else torch.randn(
-               size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
+               size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
                dtype=dtype,
                device=device,
                generator=generator,
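The `noise` argument lets a caller supply the initial diffusion noise instead of having it sampled internally, which is exactly the hook DSRL needs. A caller-supplied tensor must match the shape the model would otherwise sample; the values below are illustrative:

```python
import torch

batch_size, chunk_size, action_dim = 1, 16, 3  # illustrative shapes
generator = torch.Generator().manual_seed(0)   # seeded for reproducibility
noise = torch.randn(batch_size, chunk_size, action_dim, generator=generator)
# policy.select_action(batch, noise=noise)  # hypothetical call site
```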
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
            AND/OR
            "observation.environment_state": (B, n_obs_steps, environment_dim)

-           "action": (B, horizon, action_dim)
-           "action_is_pad": (B, horizon)
+           "action": (B, chunk_size, action_dim)
+           "action_is_pad": (B, chunk_size)
        }
        """
        # Input validation.
        assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
        assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
        n_obs_steps = batch[OBS_STATE].shape[1]
-       horizon = batch[ACTION].shape[1]
-       assert horizon == self.config.horizon
+       chunk_size = batch[ACTION].shape[1]
+       assert chunk_size == self.config.chunk_size
        assert n_obs_steps == self.config.n_obs_steps

        # Encode image features and concatenate them all together along with the state vector.
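For reference, a dummy batch matching the shapes documented above (all dimensions are illustrative):

```python
import torch

B, n_obs_steps, chunk_size, state_dim, action_dim = 4, 2, 16, 7, 7
batch = {
    "observation.state": torch.zeros(B, n_obs_steps, state_dim),
    "observation.environment_state": torch.zeros(B, n_obs_steps, 12),
    "action": torch.zeros(B, chunk_size, action_dim),
    "action_is_pad": torch.zeros(B, chunk_size, dtype=torch.bool),
}
assert batch["action"].shape[1] == chunk_size  # mirrors the config check above
```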
@@ -0,0 +1,242 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import MultiAdamConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE


def is_image_feature(key: str) -> bool:
    """Check if a feature key represents an image feature.

    Args:
        key: The feature key to check

    Returns:
        True if the key represents an image feature, False otherwise
    """
    return key.startswith(OBS_IMAGE)


@dataclass
class ConcurrencyConfig:
    """Configuration for the concurrency of the actor and learner.

    Possible values are:
    - "threads": Use threads for the actor and learner.
    - "processes": Use processes for the actor and learner.
    """

    actor: str = "threads"
    learner: str = "threads"


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051
    policy_parameters_push_frequency: int = 4
    queue_get_timeout: float = 2


@dataclass
class CriticNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    final_activation: str | None = None


@dataclass
class ActorNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    use_layer_norm: bool = True


@dataclass
class NoiseActorConfig:
    """Configuration for the noise actor in DSRL.

    The noise actor outputs noise that gets fed to the diffusion policy.
    """

    use_tanh_squash: bool = False  # Whether to bound the noise output
    std_min: float = 1e-5
    std_max: float = 2.0
    init_final: float = 0.05


@PreTrainedConfig.register_subclass("dsrl")
@dataclass
class DSRLConfig(PreTrainedConfig):
    """Diffusion Steering via Reinforcement Learning (DSRL) configuration."""

    # Mapping of feature types to normalization modes
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ENV": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Statistics for normalizing different types of inputs
    dataset_stats: dict[str, dict[str, list[float]]] | None = field(
        default_factory=lambda: {
            OBS_IMAGE: {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
            },
            OBS_STATE: {
                "min": [0.0, 0.0],
                "max": [1.0, 1.0],
            },
            ACTION: {
                "min": [0.0, 0.0, 0.0],
                "max": [1.0, 1.0, 1.0],
            },
        }
    )

    # Architecture specifics
    # Device to run the model on (e.g., "cuda", "cpu")
    device: str = "cpu"
    # Device to store the model on
    storage_device: str = "cpu"
    # Name of the vision encoder model (set to "helper2424/resnet10" for the HIL-SERL ResNet-10)
    vision_encoder_name: str | None = None
    # Whether to freeze the vision encoder during training
    freeze_vision_encoder: bool = True
    # Hidden dimension size for the image encoder
    image_encoder_hidden_dim: int = 32
    # Whether to use a shared encoder for actor and critic
    shared_encoder: bool = True
    # Number of discrete actions, e.g. for gripper actions
    num_discrete_actions: int | None = None
    # Dimension of the image embedding pooling
    image_embedding_pooling_dim: int = 8

    # Name of the action policy
    action_policy_name: str = "pi0"
    action_policy_weights: str | None = "lerobot/pi0_base"

    # Training parameters
    # Number of steps for online training
    online_steps: int = 1000000
    # Capacity of the online replay buffer
    online_buffer_capacity: int = 100000
    # Capacity of the offline replay buffer
    offline_buffer_capacity: int = 100000
    # Whether to use asynchronous prefetching for the buffers
    async_prefetch: bool = False
    # Number of steps before learning starts
    online_step_before_learning: int = 100
    # Frequency of policy updates
    policy_update_freq: int = 1

    # SAC algorithm parameters
    discount: float = 0.99
    # Initial temperature value
    temperature_init: float = 1.0
    # Number of critics in the ensemble
    num_critics: int = 2
    # Number of subsampled critics for training
    num_subsample_critics: int | None = None
    # Learning rate for the critic network
    critic_lr: float = 3e-4
    # Learning rate for the actor network
    actor_lr: float = 3e-4
    # Learning rate for the temperature parameter
    temperature_lr: float = 3e-4
    # Weight for the critic target update
    critic_target_update_weight: float = 0.005
    # Update-to-data (UTD) ratio (set it to > 1 to enable multiple updates per step)
    utd_ratio: int = 1
    # Hidden dimension size for the state encoder
    state_encoder_hidden_dim: int = 256
    # Dimension of the latent space
    latent_dim: int = 256
    # Target entropy for the SAC algorithm
    target_entropy: float | None = None
    # Whether to use backup entropy for the SAC algorithm
    use_backup_entropy: bool = True
    # Gradient clipping norm for the SAC algorithm
    grad_clip_norm: float = 40.0

    # Network configuration
    # Configuration for the critic network architecture
    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
    # Configuration for the noise critic network architecture
    noise_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
    # Configuration for the noise actor network architecture
    noise_actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
    # Configuration for the noise actor specific parameters
    noise_actor_kwargs: NoiseActorConfig = field(default_factory=NoiseActorConfig)
    # Configuration for the actor-learner architecture
    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    # Configuration for concurrency settings (threads or processes for the actor and learner)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

    # Optimizations
    use_torch_compile: bool = True

    def __post_init__(self):
        super().__post_init__()

    def get_optimizer_preset(self) -> MultiAdamConfig:
        return MultiAdamConfig(
            weight_decay=0.0,
            optimizer_groups={
                "critic_action": {"lr": self.critic_lr},
                "critic_noise": {"lr": self.critic_lr},
                "noise_actor": {"lr": self.actor_lr},
                "temperature": {"lr": self.temperature_lr},
            },
        )

    def get_scheduler_preset(self) -> None:
        return None

    def validate_features(self) -> None:
        has_image = any(is_image_feature(key) for key in self.input_features)
        has_state = OBS_STATE in self.input_features

        if not (has_state or has_image):
            raise ValueError(
                "You must provide either 'observation.state' or an image observation (key starting "
                "with 'observation.image') in the input features"
            )

        if ACTION not in self.output_features:
            raise ValueError("You must provide 'action' in the output features")

    @property
    def image_features(self) -> list[str]:
        return [key for key in self.input_features if is_image_feature(key)]

    @property
    def observation_delta_indices(self) -> list:
        return None

    @property
    def action_delta_indices(self) -> list:
        return None

    @property
    def reward_delta_indices(self) -> None:
        return None
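A usage sketch for the new config, assuming the remaining `PreTrainedConfig` fields all carry defaults so `DSRLConfig()` is constructible as-is, and that `MultiAdamConfig` exposes its `optimizer_groups` as an attribute:

```python
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig

config = DSRLConfig()
preset = config.get_optimizer_preset()
# One Adam group per trainable component, as defined in get_optimizer_preset() above.
print(preset.optimizer_groups["noise_actor"]["lr"])  # 3e-4 by default
```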
(File diff suppressed because it is too large.)
@@ -0,0 +1,89 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor for the DSRL policy.

DSRL uses a processing pipeline similar to SAC's, since it operates on
state-action transitions. The main difference is that internally it
also works with noise, but that's handled within the policy itself.
"""

from typing import Any

import torch

from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    RenameObservationsProcessorStep,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
    policy_action_to_transition,
    transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


def make_dsrl_pre_post_processors(
    config: DSRLConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict, dict],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create preprocessor and postprocessor pipelines for the DSRL policy.

    Args:
        config: DSRL policy configuration
        dataset_stats: Optional dataset statistics for normalization

    Returns:
        Tuple of (preprocessor, postprocessor) pipelines
    """
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        DeviceProcessorStep(device="cpu"),
    ]
    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
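A usage sketch (assuming `DSRLConfig()` is constructible with defaults; `dataset_stats` is optional per the docstring above):

```python
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors

preprocessor, postprocessor = make_dsrl_pre_post_processors(DSRLConfig(), dataset_stats=None)
# batch = preprocessor(raw_observation)   # rename, batch, move to device, normalize
# action = postprocessor(policy_action)   # unnormalize, move back to CPU
```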
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi05.configuration_pi05 import PI05Config
@@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:

    Args:
        name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
-           "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
+           "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "dsrl".

    Returns:
        The policy class corresponding to the given name.
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
        from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

        return SmolVLAPolicy
+   elif name == "dsrl":
+       from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
+
+       return DSRLPolicy
    elif name == "groot":
        from lerobot.policies.groot.modeling_groot import GrootPolicy

@@ -121,7 +126,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
    Args:
        policy_type: The type of the policy. Supported types include "tdmpc",
            "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
-           "reward_classifier".
+           "reward_classifier", "dsrl".
        **kwargs: Keyword arguments to be passed to the configuration class constructor.

    Returns:
@@ -148,6 +153,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
        return SmolVLAConfig(**kwargs)
    elif policy_type == "reward_classifier":
        return RewardClassifierConfig(**kwargs)
+   elif policy_type == "dsrl":
+       return DSRLConfig(**kwargs)
    elif policy_type == "groot":
        return GrootConfig(**kwargs)
    else:
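With these branches in place, the string-based factory flow works for DSRL like for any other policy:

```python
from lerobot.policies.factory import get_policy_class, make_policy_config

config = make_policy_config("dsrl")    # returns DSRLConfig(**kwargs)
policy_cls = get_policy_class("dsrl")  # lazily imports and returns DSRLPolicy
```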
@@ -321,6 +328,21 @@ def make_pre_post_processors(
            config=policy_cfg,
            dataset_stats=kwargs.get("dataset_stats"),
        )
+   elif isinstance(policy_cfg, DSRLConfig):
+       from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
+
+       processors = make_dsrl_pre_post_processors(
+           config=policy_cfg,
+           dataset_stats=kwargs.get("dataset_stats"),
+       )
+
+   elif isinstance(policy_cfg, GrootConfig):
+       from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
+
+       processors = make_groot_pre_post_processors(
+           config=policy_cfg,
+           dataset_stats=kwargs.get("dataset_stats"),
+       )
    elif isinstance(policy_cfg, GrootConfig):
        from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
@@ -1148,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
        return self._action_queue.popleft()

    @torch.no_grad()
-   def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+   def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Predict a chunk of actions given environment observations."""
        self.eval()

@@ -1158,7 +1158,7 @@ class PI0Policy(PreTrainedPolicy):
        state = self.prepare_state(batch)

        # Sample actions using the model
-       actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
+       actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise)

        # Unpad actions to actual action dimension
        original_action_dim = self.config.output_features[ACTION].shape[0]
@@ -1120,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
        return self._action_queue.popleft()

    @torch.no_grad()
-   def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+   def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
        """Predict a chunk of actions given environment observations."""
        self.eval()

@@ -1129,7 +1129,7 @@ class PI05Policy(PreTrainedPolicy):
        tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

        # Sample actions using the model (no separate state needed for PI05)
-       actions = self.model.sample_actions(images, img_masks, tokens, masks)
+       actions = self.model.sample_actions(images, img_masks, tokens, masks, noise)

        # Unpad actions to actual action dimension
        original_action_dim = self.config.output_features[ACTION].shape[0]
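Taken together with the diffusion changes above, the new `noise` parameter gives an external controller (such as DSRL's noise actor) a steering hook into the PI0/PI05 samplers. A hedged sketch; the shape is illustrative and depends on the policy config:

```python
import torch

# (batch, chunk length, padded action dim) -- illustrative values only.
noise = torch.randn(1, 50, 32)
# actions = policy.predict_action_chunk(batch, noise=noise)  # hypothetical call site
```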