mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
3 Commits
main
...
feat/add_dsrl
| Author | SHA1 | Date | |
|---|---|---|---|
| ca0087d6da | |||
| e3ce2eb743 | |||
| 17f4bc4c56 |
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# Inputs / output structure.
|
||||
n_obs_steps: int = 2
|
||||
horizon: int = 16
|
||||
chunk_size: int = 16
|
||||
n_action_steps: int = 8
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||
drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
|
||||
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# Check that the chunk size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
if self.horizon % downsampling_factor != 0:
|
||||
if self.chunk_size % downsampling_factor != 0:
|
||||
raise ValueError(
|
||||
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||
"The chunk_size should be an integer multiple of the downsampling factor (which is determined "
|
||||
f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamConfig:
|
||||
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
underlying diffusion model. Here's how it works:
|
||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||
copied `n_obs_steps` times to fill the cache).
|
||||
- The diffusion model generates `horizon` steps worth of actions.
|
||||
- The diffusion model generates `chunk_size` steps worth of actions.
|
||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||
Schematically this looks like:
|
||||
----------------------------------------------------------------------------------------------
|
||||
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||
(legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|
||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||
----------------------------------------------------------------------------------------------
|
||||
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||
"horizon" may not the best name to describe what the variable actually means, because this period is
|
||||
Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
|
||||
this period is
|
||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||
"""
|
||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
|
||||
noise
|
||||
if noise is not None
|
||||
else torch.randn(
|
||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||
size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
|
||||
AND/OR
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
"action": (B, chunk_size, action_dim)
|
||||
"action_is_pad": (B, chunk_size)
|
||||
}
|
||||
"""
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||
horizon = batch[ACTION].shape[1]
|
||||
assert horizon == self.config.horizon
|
||||
chunk_size = batch[ACTION].shape[1]
|
||||
assert chunk_size == self.config.chunk_size
|
||||
assert n_obs_steps == self.config.n_obs_steps
|
||||
|
||||
# Encode image features and concatenate them all together along with the state vector.
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
def is_image_feature(key: str) -> bool:
|
||||
"""Check if a feature key represents an image feature.
|
||||
|
||||
Args:
|
||||
key: The feature key to check
|
||||
|
||||
Returns:
|
||||
True if the key represents an image feature, False otherwise
|
||||
"""
|
||||
return key.startswith(OBS_IMAGE)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
"""Configuration for the concurrency of the actor and learner.
|
||||
Possible values are:
|
||||
- "threads": Use threads for the actor and learner.
|
||||
- "processes": Use processes for the actor and learner.
|
||||
"""
|
||||
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
queue_get_timeout: float = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
use_layer_norm: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoiseActorConfig:
|
||||
"""Configuration for the noise actor in DSRL.
|
||||
The noise actor outputs noise that gets fed to the diffusion policy.
|
||||
"""
|
||||
|
||||
use_tanh_squash: bool = False # Whether to bound the noise output
|
||||
std_min: float = 1e-5
|
||||
std_max: float = 2.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("dsrl")
|
||||
@dataclass
|
||||
class DSRLConfig(PreTrainedConfig):
|
||||
"""Diffusion Steering via Reinforcement Learning (DSRL) configuration."""
|
||||
|
||||
# Mapping of feature types to normalization modes
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
# Statistics for normalizing different types of inputs
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
ACTION: {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
# Device to run the model on (e.g., "cuda", "cpu")
|
||||
device: str = "cpu"
|
||||
# Device to store the model on
|
||||
storage_device: str = "cpu"
|
||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
||||
vision_encoder_name: str | None = None
|
||||
# Whether to freeze the vision encoder during training
|
||||
freeze_vision_encoder: bool = True
|
||||
# Hidden dimension size for the image encoder
|
||||
image_encoder_hidden_dim: int = 32
|
||||
# Whether to use a shared encoder for actor and critic
|
||||
shared_encoder: bool = True
|
||||
# Number of discrete actions, eg for gripper actions
|
||||
num_discrete_actions: int | None = None
|
||||
# Dimension of the image embedding pooling
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Name of the action policy
|
||||
action_policy_name: str = "pi0"
|
||||
action_policy_weights: str | None = "lerobot/pi0_base"
|
||||
|
||||
# Training parameter
|
||||
# Number of steps for online training
|
||||
online_steps: int = 1000000
|
||||
# Capacity of the online replay buffer
|
||||
online_buffer_capacity: int = 100000
|
||||
# Capacity of the offline replay buffer
|
||||
offline_buffer_capacity: int = 100000
|
||||
# Whether to use asynchronous prefetching for the buffers
|
||||
async_prefetch: bool = False
|
||||
# Number of steps before learning starts
|
||||
online_step_before_learning: int = 100
|
||||
# Frequency of policy updates
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
discount: float = 0.99
|
||||
# Initial temperature value
|
||||
temperature_init: float = 1.0
|
||||
# Number of critics in the ensemble
|
||||
num_critics: int = 2
|
||||
# Number of subsampled critics for training
|
||||
num_subsample_critics: int | None = None
|
||||
# Learning rate for the critic network
|
||||
critic_lr: float = 3e-4
|
||||
# Learning rate for the actor network
|
||||
actor_lr: float = 3e-4
|
||||
# Learning rate for the temperature parameter
|
||||
temperature_lr: float = 3e-4
|
||||
# Weight for the critic target update
|
||||
critic_target_update_weight: float = 0.005
|
||||
# Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1)
|
||||
utd_ratio: int = 1
|
||||
# Hidden dimension size for the state encoder
|
||||
state_encoder_hidden_dim: int = 256
|
||||
# Dimension of the latent space
|
||||
latent_dim: int = 256
|
||||
# Target entropy for the SAC algorithm
|
||||
target_entropy: float | None = None
|
||||
# Whether to use backup entropy for the SAC algorithm
|
||||
use_backup_entropy: bool = True
|
||||
# Gradient clipping norm for the SAC algorithm
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
# Configuration for the critic network architecture
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the noise critic network architecture
|
||||
noise_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
# Configuration for the noise actor network architecture
|
||||
noise_actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
# Configuration for the noise actor specific parameters
|
||||
noise_actor_kwargs: NoiseActorConfig = field(default_factory=NoiseActorConfig)
|
||||
# Configuration for actor-learner architecture
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
# Optimizations
|
||||
use_torch_compile: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"critic_action": {"lr": self.critic_lr},
|
||||
"critic_noise": {"lr": self.critic_lr},
|
||||
"noise_actor": {"lr": self.actor_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
||||
has_state = OBS_STATE in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if is_image_feature(key)]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,89 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Processor for DSRL policy.
|
||||
|
||||
DSRL uses a similar processing pipeline as SAC since it operates on
|
||||
state-action transitions. The main difference is that internally it
|
||||
also works with noise, but that's handled within the policy itself.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
RenameObservationsProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_dsrl_pre_post_processors(
|
||||
config: DSRLConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict, dict],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create preprocessor and postprocessor pipelines for DSRL policy.
|
||||
|
||||
Args:
|
||||
config: DSRL policy configuration
|
||||
dataset_stats: Optional dataset statistics for normalization
|
||||
|
||||
Returns:
|
||||
Tuple of (preprocessor, postprocessor) pipelines
|
||||
"""
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.utils import env_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
@@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "dsrl".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "dsrl":
|
||||
from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
|
||||
|
||||
return DSRLPolicy
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -121,7 +126,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier".
|
||||
"reward_classifier", "dsrl".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -148,6 +153,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return SmolVLAConfig(**kwargs)
|
||||
elif policy_type == "reward_classifier":
|
||||
return RewardClassifierConfig(**kwargs)
|
||||
elif policy_type == "dsrl":
|
||||
return DSRLConfig(**kwargs)
|
||||
elif policy_type == "groot":
|
||||
return GrootConfig(**kwargs)
|
||||
else:
|
||||
@@ -321,6 +328,21 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
elif isinstance(policy_cfg, DSRLConfig):
|
||||
from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
|
||||
|
||||
processors = make_dsrl_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
processors = make_groot_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -1148,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1158,7 +1158,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, noise)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
Reference in New Issue
Block a user