Compare commits

...

4 Commits

Author SHA1 Message Date
Michel Aractingi e5e1c97a6c pi0 hack 2025-10-30 14:17:36 +01:00
Michel Aractingi 1594ae60a7 * Change Diffusion policy to use chunk_size notation instead of horizon to standardize the variable names across policies
* reshape noise after taking it as output of the network
2025-10-29 15:22:27 +01:00
Michel Aractingi 7cd710857d update factory with dsrl 2025-10-13 16:12:39 +02:00
Michel Aractingi 5c9bfd57ec Add dsrl policy files 2025-10-13 15:45:16 +02:00
9 changed files with 1578 additions and 29 deletions
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
     Args:
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+        chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
         n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
             See `DiffusionPolicy.select_action` for more details.
         input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
     # Inputs / output structure.
     n_obs_steps: int = 2
-    horizon: int = 16
+    chunk_size: int = 16
     n_action_steps: int = 8

     normalization_mapping: dict[str, NormalizationMode] = field(
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
     # The original implementation doesn't sample frames for the last 7 steps,
     # which avoids excessive padding and leads to improved training results.
-    drop_n_last_frames: int = 7  # horizon - n_action_steps - n_obs_steps + 1
+    drop_n_last_frames: int = 7  # chunk_size - n_action_steps - n_obs_steps + 1

     # Architecture / modeling.
     # Vision backbone.
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
                 f"Got {self.noise_scheduler_type}."
             )

-        # Check that the horizon size and U-Net downsampling is compatible.
+        # Check that the chunk size and U-Net downsampling is compatible.
         # U-Net downsamples by 2 with each stage.
         downsampling_factor = 2 ** len(self.down_dims)
-        if self.horizon % downsampling_factor != 0:
+        if self.chunk_size % downsampling_factor != 0:
             raise ValueError(
-                "The horizon should be an integer multiple of the downsampling factor (which is determined "
-                f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
+                "The chunk_size should be an integer multiple of the downsampling factor (which is determined "
+                f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
             )

     def get_optimizer_preset(self) -> AdamConfig:
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
     @property
     def action_delta_indices(self) -> list:
-        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
+        return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))

     @property
     def reward_delta_indices(self) -> None:
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
         return actions

     @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
         """Select a single action given environment observations.

         This method handles caching a history of observations and an action trajectory generated by the
         underlying diffusion model. Here's how it works:
           - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
             copied `n_obs_steps` times to fill the cache).
-          - The diffusion model generates `horizon` steps worth of actions.
+          - The diffusion model generates `chunk_size` steps worth of actions.
           - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.

         Schematically this looks like:
             ----------------------------------------------------------------------------------------------
-            (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
+            (legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
             |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... | n-o+h |
             |observation is used | YES   | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    |
             |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
             |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
             ----------------------------------------------------------------------------------------------
-        Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
-        "horizon" may not the best name to describe what the variable actually means, because this period is
-        actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
+        Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
+        this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened
+        in the past.
         """
         # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
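For intuition, here is a runnable sketch of the queueing behavior the docstring above describes: one model call yields `chunk_size` actions, the first `n_action_steps` are queued, and one action is popped per environment step. The `generate_chunk` callable is a stand-in for the diffusion model, not part of the diff.

    from collections import deque

    # Defaults from DiffusionConfig above; the constraint is the one stated in the docstring.
    n_obs_steps, chunk_size, n_action_steps = 2, 16, 8
    assert n_action_steps <= chunk_size - n_obs_steps + 1

    action_queue: deque = deque(maxlen=n_action_steps)

    def select_action(observation, generate_chunk):
        """Pop one action per call; refill from a fresh chunk when the queue is empty."""
        if not action_queue:
            chunk = generate_chunk(observation)  # shape: (chunk_size, action_dim)
            action_queue.extend(chunk[:n_action_steps])
        return action_queue.popleft()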
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
             noise
             if noise is not None
             else torch.randn(
-                size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
+                size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
                 dtype=dtype,
                 device=device,
                 generator=generator,
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
             AND/OR
             "observation.environment_state": (B, n_obs_steps, environment_dim)
-            "action": (B, horizon, action_dim)
-            "action_is_pad": (B, horizon)
+            "action": (B, chunk_size, action_dim)
+            "action_is_pad": (B, chunk_size)
         }
         """
         # Input validation.
         assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
         assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
         n_obs_steps = batch[OBS_STATE].shape[1]
-        horizon = batch[ACTION].shape[1]
-        assert horizon == self.config.horizon
+        chunk_size = batch[ACTION].shape[1]
+        assert chunk_size == self.config.chunk_size
         assert n_obs_steps == self.config.n_obs_steps

         # Encode image features and concatenate them all together along with the state vector.
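To make the documented input shapes concrete, a hypothetical training batch matching the docstring above (batch size and feature dimensions are placeholders, not values from the diff):

    import torch

    B, n_obs_steps, chunk_size, state_dim, action_dim = 4, 2, 16, 7, 7
    batch = {
        "observation.state": torch.randn(B, n_obs_steps, state_dim),
        "observation.environment_state": torch.randn(B, n_obs_steps, 12),  # and/or image keys
        "action": torch.randn(B, chunk_size, action_dim),
        "action_is_pad": torch.zeros(B, chunk_size, dtype=torch.bool),
    }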
@@ -0,0 +1,244 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.optim.optimizers import MultiAdamConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE


def is_image_feature(key: str) -> bool:
    """Check if a feature key represents an image feature.

    Args:
        key: The feature key to check

    Returns:
        True if the key represents an image feature, False otherwise
    """
    return key.startswith(OBS_IMAGE)


@dataclass
class ConcurrencyConfig:
    """Configuration for the concurrency of the actor and learner.

    Possible values are:
    - "threads": Use threads for the actor and learner.
    - "processes": Use processes for the actor and learner.
    """

    actor: str = "threads"
    learner: str = "threads"


@dataclass
class ActorLearnerConfig:
    learner_host: str = "127.0.0.1"
    learner_port: int = 50051
    policy_parameters_push_frequency: int = 4
    queue_get_timeout: float = 2


@dataclass
class CriticNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    final_activation: str | None = None


@dataclass
class ActorNetworkConfig:
    hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
    activate_final: bool = True
    use_layer_norm: bool = True


@dataclass
class NoiseActorConfig:
    """Configuration for the noise actor in DSRL.

    The noise actor outputs noise that gets fed to the diffusion policy.
    """

    use_tanh_squash: bool = False  # Whether to bound the noise output
    std_min: float = 1e-5
    std_max: float = 2.0
    init_final: float = 0.05


@PreTrainedConfig.register_subclass("dsrl")
@dataclass
class DSRLConfig(PreTrainedConfig):
    """Diffusion Steering via Reinforcement Learning (DSRL) configuration."""

    # Mapping of feature types to normalization modes
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.MEAN_STD,
            "STATE": NormalizationMode.MIN_MAX,
            "ENV": NormalizationMode.MIN_MAX,
            "ACTION": NormalizationMode.MIN_MAX,
        }
    )

    # Statistics for normalizing different types of inputs
    dataset_stats: dict[str, dict[str, list[float]]] | None = field(
        default_factory=lambda: {
            OBS_IMAGE: {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
            },
            OBS_STATE: {
                "min": [0.0, 0.0],
                "max": [1.0, 1.0],
            },
            ACTION: {
                "min": [0.0, 0.0, 0.0],
                "max": [1.0, 1.0, 1.0],
            },
        }
    )

    # Architecture specifics
    # Device to run the model on (e.g., "cuda", "cpu")
    device: str = "cpu"
    # Device to store the model on
    storage_device: str = "cpu"
    # Name of the vision encoder model (set to "helper2424/resnet10" for the HIL-SERL ResNet-10)
    vision_encoder_name: str | None = None
    # Whether to freeze the vision encoder during training
    freeze_vision_encoder: bool = True
    # Hidden dimension size for the image encoder
    image_encoder_hidden_dim: int = 32
    # Whether to use a shared encoder for actor and critic
    shared_encoder: bool = True
    # Number of discrete actions, e.g. for gripper actions
    num_discrete_actions: int | None = None
    # Dimension of the image embedding pooling
    image_embedding_pooling_dim: int = 8
    # Name of the action policy
    action_policy_name: str = "pi0"
    action_policy_weights: str | None = "lerobot/pi0_base"

    # Training parameters
    # Number of steps for online training
    online_steps: int = 1000000
    # Number of steps for offline training
    offline_steps: int = 100000
    # Capacity of the online replay buffer
    online_buffer_capacity: int = 100000
    # Capacity of the offline replay buffer
    offline_buffer_capacity: int = 100000
    # Whether to use asynchronous prefetching for the buffers
    async_prefetch: bool = False
    # Number of steps before learning starts
    online_step_before_learning: int = 100
    # Frequency of policy updates
    policy_update_freq: int = 1

    # SAC algorithm parameters
    discount: float = 0.99
    # Initial temperature value
    temperature_init: float = 1.0
    # Number of critics in the ensemble
    num_critics: int = 2
    # Number of subsampled critics for training
    num_subsample_critics: int | None = None
    # Learning rate for the critic network
    critic_lr: float = 3e-4
    # Learning rate for the actor network
    actor_lr: float = 3e-4
    # Learning rate for the temperature parameter
    temperature_lr: float = 3e-4
    # Weight for the critic target update
    critic_target_update_weight: float = 0.005
    # Update-to-data (UTD) ratio (to enable it, set a value > 1)
    utd_ratio: int = 1
    # Hidden dimension size for the state encoder
    state_encoder_hidden_dim: int = 256
    # Dimension of the latent space
    latent_dim: int = 256
    # Target entropy for the SAC algorithm
    target_entropy: float | None = None
    # Whether to use backup entropy for the SAC algorithm
    use_backup_entropy: bool = True
    # Gradient clipping norm for the SAC algorithm
    grad_clip_norm: float = 40.0

    # Network configuration
    # Configuration for the critic network architecture
    critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
    # Configuration for the noise critic network architecture
    noise_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
    # Configuration for the noise actor network architecture
    noise_actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
    # Configuration for the noise actor specific parameters
    noise_actor_kwargs: NoiseActorConfig = field(default_factory=NoiseActorConfig)
    # Configuration for the actor-learner architecture
    actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
    # Configuration for concurrency settings (threads or processes for the actor and learner)
    concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)

    # Optimizations
    use_torch_compile: bool = True

    def __post_init__(self):
        super().__post_init__()

    def get_optimizer_preset(self) -> MultiAdamConfig:
        return MultiAdamConfig(
            weight_decay=0.0,
            optimizer_groups={
                "critic_action": {"lr": self.critic_lr},
                "critic_noise": {"lr": self.critic_lr},
                "noise_actor": {"lr": self.actor_lr},
                "temperature": {"lr": self.temperature_lr},
            },
        )

    def get_scheduler_preset(self) -> None:
        return None

    def validate_features(self) -> None:
        has_image = any(is_image_feature(key) for key in self.input_features)
        has_state = OBS_STATE in self.input_features
        if not (has_state or has_image):
            raise ValueError(
                "You must provide either 'observation.state' or an image observation "
                "(key starting with 'observation.image') in the input features"
            )
        if ACTION not in self.output_features:
            raise ValueError("You must provide 'action' in the output features")

    @property
    def image_features(self) -> list[str]:
        return [key for key in self.input_features if is_image_feature(key)]

    @property
    def observation_delta_indices(self) -> list | None:
        return None

    @property
    def action_delta_indices(self) -> list | None:
        return None

    @property
    def reward_delta_indices(self) -> None:
        return None
File diff suppressed because it is too large.
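Since the modeling file's diff is suppressed here, the following is only a hedged sketch of how the NoiseActorConfig fields above might be consumed by a Gaussian noise head; the function name and head layout are assumptions, not the actual implementation:

    import torch

    from lerobot.policies.dsrl.configuration_dsrl import NoiseActorConfig

    def sample_noise(mean: torch.Tensor, log_std: torch.Tensor, cfg: NoiseActorConfig) -> torch.Tensor:
        # Clamp the standard deviation into [std_min, std_max] (assumed use of the config bounds).
        std = log_std.exp().clamp(cfg.std_min, cfg.std_max)
        noise = mean + std * torch.randn_like(mean)  # reparameterized Gaussian sample
        if cfg.use_tanh_squash:
            noise = torch.tanh(noise)  # bound the noise fed to the diffusion policy to (-1, 1)
        return noise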
@@ -0,0 +1,89 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor for the DSRL policy.

DSRL uses a processing pipeline similar to SAC's, since it operates on
state-action transitions. The main difference is that internally it
also works with noise, but that is handled within the policy itself.
"""

from typing import Any

import torch

from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
from lerobot.processor import (
    AddBatchDimensionProcessorStep,
    DeviceProcessorStep,
    NormalizerProcessorStep,
    PolicyAction,
    PolicyProcessorPipeline,
    RenameObservationsProcessorStep,
    UnnormalizerProcessorStep,
)
from lerobot.processor.converters import (
    policy_action_to_transition,
    transition_to_policy_action,
)
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME


def make_dsrl_pre_post_processors(
    config: DSRLConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict, dict],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create preprocessor and postprocessor pipelines for the DSRL policy.

    Args:
        config: DSRL policy configuration
        dataset_stats: Optional dataset statistics for normalization

    Returns:
        Tuple of (preprocessor, postprocessor) pipelines
    """
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device=config.device),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
    ]
    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
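A hypothetical call pattern for this factory (feature definitions and stats would normally come from the dataset; nothing below is prescribed by the diff):

    from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
    from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors

    config = DSRLConfig()  # input_features / output_features assumed populated elsewhere
    preprocessor, postprocessor = make_dsrl_pre_post_processors(config, dataset_stats=None)
    # Observations pass through `preprocessor` before the policy
    # (rename -> add batch dim -> move to device -> normalize), and raw policy
    # actions pass through `postprocessor` (unnormalize -> move to CPU) before the environment.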
+16 -2
@@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig
 from lerobot.envs.utils import env_to_policy_features
 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
 from lerobot.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.policies.pi05.configuration_pi05 import PI05Config
@@ -58,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
     Args:
         name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
-            "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla".
+            "vqbet", "pi0", "pi0fast", "sac", "reward_classifier", "smolvla", "dsrl".

     Returns:
         The policy class corresponding to the given name.
@@ -106,6 +107,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

         return SmolVLAPolicy
+    elif name == "dsrl":
+        from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
+
+        return DSRLPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -120,7 +125,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
     Args:
         policy_type: The type of the policy. Supported types include "tdmpc",
             "diffusion", "act", "vqbet", "pi0", "pi0fast", "sac", "smolvla",
-            "reward_classifier".
+            "reward_classifier", "dsrl".
         **kwargs: Keyword arguments to be passed to the configuration class constructor.

     Returns:
@@ -149,6 +154,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return SmolVLAConfig(**kwargs)
     elif policy_type == "reward_classifier":
         return RewardClassifierConfig(**kwargs)
+    elif policy_type == "dsrl":
+        return DSRLConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")
@@ -307,6 +314,13 @@ def make_pre_post_processors(
             config=policy_cfg,
             dataset_stats=kwargs.get("dataset_stats"),
         )
+    elif isinstance(policy_cfg, DSRLConfig):
+        from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
+
+        processors = make_dsrl_pre_post_processors(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+        )
     else:
         raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
+2 -2
@@ -1148,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
         return self._action_queue.popleft()

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()
@@ -1158,7 +1158,7 @@ class PI0Policy(PreTrainedPolicy):
         state = self.prepare_state(batch)

         # Sample actions using the model
-        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
+        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise)

         # Unpad actions to actual action dimension
         original_action_dim = self.config.output_features[ACTION].shape[0]
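The new `noise` parameter is what lets an external caller, such as the DSRL noise actor, steer pi0's sampler; a hedged sketch of the intended call pattern (the `noise_actor` callable and its output shape are assumptions):

    from torch import Tensor

    def steered_action_chunk(policy, batch: dict[str, Tensor], noise_actor) -> Tensor:
        """Drive a pretrained pi0 policy with externally chosen noise instead of fresh Gaussian noise."""
        noise = noise_actor(batch)  # assumed shape: (B, chunk_size, action_dim)
        # Passing noise=None preserves the previous behavior (the policy samples its own noise).
        return policy.predict_action_chunk(batch, noise=noise)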
+2 -2
@@ -1120,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
         return self._action_queue.popleft()

     @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Predict a chunk of actions given environment observations."""
         self.eval()
@@ -1129,7 +1129,7 @@ class PI05Policy(PreTrainedPolicy):
         tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]

         # Sample actions using the model (no separate state needed for PI05)
-        actions = self.model.sample_actions(images, img_masks, tokens, masks)
+        actions = self.model.sample_actions(images, img_masks, tokens, masks, noise)

         # Unpad actions to actual action dimension
         original_action_dim = self.config.output_features[ACTION].shape[0]
+10 -5
@@ -86,7 +86,7 @@ class ReplayBuffer:
         image_augmentation_function: Callable | None = None,
         use_drq: bool = True,
         storage_device: str = "cpu",
-        optimize_memory: bool = False,
+        optimize_memory: bool = True,
     ):
         """
         Replay buffer for storing transitions.
@@ -136,6 +136,7 @@
         complementary_info: dict[str, torch.Tensor] | None = None,
     ):
         """Initialize the storage tensors based on the first transition."""
+        self.capacity = 1000
         # Determine shapes from the first transition
         state_shapes = {key: val.squeeze(0).shape for key, val in state.items()}
         action_shape = action.squeeze(0).shape
@@ -444,7 +445,7 @@
         if capacity is None:
             capacity = len(lerobot_dataset)

-        if capacity < len(lerobot_dataset):
+        if capacity < 1000:  # len(lerobot_dataset):
             raise ValueError(
                 "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
             )
@@ -476,13 +477,14 @@
             and first_transition["complementary_info"] is not None
         ):
             first_complementary_info = {
-                k: v.to(device) for k, v in first_transition["complementary_info"].items()
+                k: v.to for k, v in first_transition["complementary_info"].items()
             }
         replay_buffer._initialize_storage(
             state=first_state, action=first_action, complementary_info=first_complementary_info
         )

+        num_samples = 0
         # Fill the buffer with all transitions
         for data in list_transition:
             for k, v in data.items():
@@ -503,6 +505,9 @@
                 truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
                 complementary_info=data.get("complementary_info", None),
             )
+            num_samples += 1
+            if num_samples >= 1000:
+                return replay_buffer

         return replay_buffer
@@ -645,7 +650,7 @@
         raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

     transitions = []
-    num_frames = len(dataset)
+    num_frames = 1000  # len(dataset)

     # Check if the dataset has "next.done" key
     sample = dataset[0]
@@ -659,7 +664,7 @@
     if not has_done_key:
         print("'next.done' key not found in dataset. Inferring from episode boundaries...")

-    for i in tqdm(range(num_frames)):
+    for i in tqdm(range(1000)):  # num_frames)):
         current_sample = dataset[i]

         # ----- 1) Current state -----
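As the "pi0 hack" commit message suggests, these buffer edits hard-cap dataset loading at 1000 transitions for quick experiments; their net effect is roughly this sketch (the `max_transitions` parameter is hypothetical, introduced only for illustration):

    def iter_truncated(dataset, max_transitions: int = 1000):
        """Yield at most `max_transitions` frames from a dataset (what the hard-coded 1000s amount to)."""
        for i in range(min(len(dataset), max_transitions)):
            yield dataset[i]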