mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c11d8f1bb6 | |||
| 6001b2c3ad | |||
| a5b29d4301 | |||
| a4aa316470 |
@@ -83,11 +83,11 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
- name: Remove Tags with Git dependencies
|
- name: Remove Tags with Git dependencies
|
||||||
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||||
run: |
|
run: |
|
||||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||||
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||||
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
|
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||||
|
|
||||||
- name: Install build dependencies
|
- name: Install build dependencies
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ jobs:
|
|||||||
echo "Dependencies unbound:" && cat pyproject.toml
|
echo "Dependencies unbound:" && cat pyproject.toml
|
||||||
|
|
||||||
- name: Install lerobot with all extras
|
- name: Install lerobot with all extras
|
||||||
run: uv sync --all-extras
|
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||||
|
|
||||||
- name: Run pytest (all extras)
|
- name: Run pytest (all extras)
|
||||||
run: uv run pytest tests -vv
|
run: uv run pytest tests -vv
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ For a full list of optional dependencies, see:
|
|||||||
https://pypi.org/project/lerobot/
|
https://pypi.org/project/lerobot/
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
|
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||||
>
|
>
|
||||||
> This will be solved in the next patch release
|
> This will be solved in the next patch release
|
||||||
|
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ For a full list of optional dependencies, see:
|
|||||||
https://pypi.org/project/lerobot/
|
https://pypi.org/project/lerobot/
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
|
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
|
|||||||
@@ -28,11 +28,6 @@ LIBERO is now part of our **multi-eval supported simulation**, meaning you can b
|
|||||||
To install LIBERO, after following the official LeRobot installation instructions, run:
|
To install LIBERO, after following the official LeRobot installation instructions, run:
|
||||||
`pip install -e ".[libero]"`
|
`pip install -e ".[libero]"`
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> For lerobot 0.4.0, if you want to install libero tag, you will have to do: `pip install "lerobot[libero]@git+https://github.com/huggingface/lerobot.git"`.
|
|
||||||
>
|
|
||||||
> This will be solved in the next patch release
|
|
||||||
|
|
||||||
### Single-suite evaluation
|
### Single-suite evaluation
|
||||||
|
|
||||||
Evaluate a policy on one LIBERO suite:
|
Evaluate a policy on one LIBERO suite:
|
||||||
|
|||||||
@@ -940,11 +940,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
"""
|
||||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
Query dataset for indices across keys, skipping video keys.
|
||||||
for key, q_idx in query_indices.items()
|
|
||||||
if key not in self.meta.video_keys
|
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||||
}
|
|
||||||
|
Args:
|
||||||
|
query_indices: Dict mapping keys to index lists to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with stacked tensors of queried data (video keys excluded)
|
||||||
|
"""
|
||||||
|
result: dict = {}
|
||||||
|
for key, q_idx in query_indices.items():
|
||||||
|
if key in self.meta.video_keys:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
result[key] = torch.stack(self.hf_dataset[key][q_idx])
|
||||||
|
except (KeyError, TypeError, IndexError):
|
||||||
|
result[key] = torch.stack(self.hf_dataset[q_idx][key])
|
||||||
|
return result
|
||||||
|
|
||||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
|
|||||||
@@ -237,9 +237,10 @@ class LiberoEnv(gym.Env):
|
|||||||
def reset(self, seed=None, **kwargs):
|
def reset(self, seed=None, **kwargs):
|
||||||
super().reset(seed=seed)
|
super().reset(seed=seed)
|
||||||
self._env.seed(seed)
|
self._env.seed(seed)
|
||||||
|
raw_obs = self._env.reset()
|
||||||
if self.init_states and self._init_states is not None:
|
if self.init_states and self._init_states is not None:
|
||||||
self._env.set_init_state(self._init_states[self._init_state_id])
|
self._env.set_init_state(self._init_states[self._init_state_id])
|
||||||
raw_obs = self._env.reset()
|
raw_obs = self._env.env._get_observations()
|
||||||
|
|
||||||
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
# After reset, objects may be unstable (slightly floating, intersecting, etc.).
|
||||||
# Step the simulator with a no-op action for a few frames so everything settles.
|
# Step the simulator with a no-op action for a few frames so everything settles.
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||||
current step and additional steps going back).
|
current step and additional steps going back).
|
||||||
chunk_size: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||||
See `DiffusionPolicy.select_action` for more details.
|
See `DiffusionPolicy.select_action` for more details.
|
||||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||||
@@ -105,7 +105,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Inputs / output structure.
|
# Inputs / output structure.
|
||||||
n_obs_steps: int = 2
|
n_obs_steps: int = 2
|
||||||
chunk_size: int = 16
|
horizon: int = 16
|
||||||
n_action_steps: int = 8
|
n_action_steps: int = 8
|
||||||
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
@@ -118,7 +118,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# The original implementation doesn't sample frames for the last 7 steps,
|
# The original implementation doesn't sample frames for the last 7 steps,
|
||||||
# which avoids excessive padding and leads to improved training results.
|
# which avoids excessive padding and leads to improved training results.
|
||||||
drop_n_last_frames: int = 7 # chunk_size - n_action_steps - n_obs_steps + 1
|
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
|
||||||
|
|
||||||
# Architecture / modeling.
|
# Architecture / modeling.
|
||||||
# Vision backbone.
|
# Vision backbone.
|
||||||
@@ -180,13 +180,13 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
f"Got {self.noise_scheduler_type}."
|
f"Got {self.noise_scheduler_type}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check that the chunk size and U-Net downsampling are compatible.
|
# Check that the horizon size and U-Net downsampling are compatible.
|
||||||
# U-Net downsamples by 2 with each stage.
|
# U-Net downsamples by 2 with each stage.
|
||||||
downsampling_factor = 2 ** len(self.down_dims)
|
downsampling_factor = 2 ** len(self.down_dims)
|
||||||
if self.chunk_size % downsampling_factor != 0:
|
if self.horizon % downsampling_factor != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The chunk_size should be an integer multiple of the downsampling factor (which is determined "
|
"The horizon should be an integer multiple of the downsampling factor (which is determined "
|
||||||
f"by `len(down_dims)`). Got {self.chunk_size=} and {self.down_dims=}"
|
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> AdamConfig:
|
def get_optimizer_preset(self) -> AdamConfig:
|
||||||
@@ -231,7 +231,7 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list:
|
def action_delta_indices(self) -> list:
|
||||||
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.chunk_size))
|
return list(range(1 - self.n_obs_steps, 1 - self.n_obs_steps + self.horizon))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reward_delta_indices(self) -> None:
|
def reward_delta_indices(self) -> None:
|
||||||
|
|||||||
@@ -99,25 +99,25 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
This method handles caching a history of observations and an action trajectory generated by the
|
This method handles caching a history of observations and an action trajectory generated by the
|
||||||
underlying diffusion model. Here's how it works:
|
underlying diffusion model. Here's how it works:
|
||||||
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
- `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is
|
||||||
copied `n_obs_steps` times to fill the cache).
|
copied `n_obs_steps` times to fill the cache).
|
||||||
- The diffusion model generates `chunk_size` steps worth of actions.
|
- The diffusion model generates `horizon` steps worth of actions.
|
||||||
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
- `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
|
||||||
Schematically this looks like:
|
Schematically this looks like:
|
||||||
----------------------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------------------
|
||||||
(legend: o = n_obs_steps, c = chunk_size, a = n_action_steps)
|
(legend: o = n_obs_steps, h = horizon, a = n_action_steps)
|
||||||
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
|timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... | n-o+h |
|
||||||
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
|observation is used | YES | YES | YES | YES | NO | NO | NO | NO | NO |
|
||||||
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
|action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES |
|
||||||
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
|action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO |
|
||||||
----------------------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------------------
|
||||||
Note that this means we require: `n_action_steps <= chunk_size - n_obs_steps + 1`. Also, note that
|
Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
|
||||||
this period is
|
"horizon" may not be the best name to describe what the variable actually means, because this period is
|
||||||
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
|
||||||
"""
|
"""
|
||||||
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
|
||||||
@@ -213,7 +213,7 @@ class DiffusionModel(nn.Module):
|
|||||||
noise
|
noise
|
||||||
if noise is not None
|
if noise is not None
|
||||||
else torch.randn(
|
else torch.randn(
|
||||||
size=(batch_size, self.config.chunk_size, self.config.action_feature.shape[0]),
|
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
@@ -309,16 +309,16 @@ class DiffusionModel(nn.Module):
|
|||||||
AND/OR
|
AND/OR
|
||||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
|
|
||||||
"action": (B, chunk_size, action_dim)
|
"action": (B, horizon, action_dim)
|
||||||
"action_is_pad": (B, chunk_size)
|
"action_is_pad": (B, horizon)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# Input validation.
|
# Input validation.
|
||||||
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
|
||||||
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
|
||||||
n_obs_steps = batch[OBS_STATE].shape[1]
|
n_obs_steps = batch[OBS_STATE].shape[1]
|
||||||
chunk_size = batch[ACTION].shape[1]
|
horizon = batch[ACTION].shape[1]
|
||||||
assert chunk_size == self.config.chunk_size
|
assert horizon == self.config.horizon
|
||||||
assert n_obs_steps == self.config.n_obs_steps
|
assert n_obs_steps == self.config.n_obs_steps
|
||||||
|
|
||||||
# Encode image features and concatenate them all together along with the state vector.
|
# Encode image features and concatenate them all together along with the state vector.
|
||||||
|
|||||||
@@ -1,242 +0,0 @@
|
|||||||
# !/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
|
||||||
from lerobot.configs.types import NormalizationMode
|
|
||||||
from lerobot.optim.optimizers import MultiAdamConfig
|
|
||||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
|
||||||
|
|
||||||
|
|
||||||
def is_image_feature(key: str) -> bool:
|
|
||||||
"""Check if a feature key represents an image feature.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The feature key to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the key represents an image feature, False otherwise
|
|
||||||
"""
|
|
||||||
return key.startswith(OBS_IMAGE)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ConcurrencyConfig:
|
|
||||||
"""Configuration for the concurrency of the actor and learner.
|
|
||||||
Possible values are:
|
|
||||||
- "threads": Use threads for the actor and learner.
|
|
||||||
- "processes": Use processes for the actor and learner.
|
|
||||||
"""
|
|
||||||
|
|
||||||
actor: str = "threads"
|
|
||||||
learner: str = "threads"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ActorLearnerConfig:
|
|
||||||
learner_host: str = "127.0.0.1"
|
|
||||||
learner_port: int = 50051
|
|
||||||
policy_parameters_push_frequency: int = 4
|
|
||||||
queue_get_timeout: float = 2
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CriticNetworkConfig:
|
|
||||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
|
||||||
activate_final: bool = True
|
|
||||||
final_activation: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ActorNetworkConfig:
|
|
||||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
|
||||||
activate_final: bool = True
|
|
||||||
use_layer_norm: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NoiseActorConfig:
|
|
||||||
"""Configuration for the noise actor in DSRL.
|
|
||||||
The noise actor outputs noise that gets fed to the diffusion policy.
|
|
||||||
"""
|
|
||||||
|
|
||||||
use_tanh_squash: bool = False # Whether to bound the noise output
|
|
||||||
std_min: float = 1e-5
|
|
||||||
std_max: float = 2.0
|
|
||||||
init_final: float = 0.05
|
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("dsrl")
|
|
||||||
@dataclass
|
|
||||||
class DSRLConfig(PreTrainedConfig):
|
|
||||||
"""Diffusion Steering via Reinforcement Learning (DSRL) configuration."""
|
|
||||||
|
|
||||||
# Mapping of feature types to normalization modes
|
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
"VISUAL": NormalizationMode.MEAN_STD,
|
|
||||||
"STATE": NormalizationMode.MIN_MAX,
|
|
||||||
"ENV": NormalizationMode.MIN_MAX,
|
|
||||||
"ACTION": NormalizationMode.MIN_MAX,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Statistics for normalizing different types of inputs
|
|
||||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
|
||||||
default_factory=lambda: {
|
|
||||||
OBS_IMAGE: {
|
|
||||||
"mean": [0.485, 0.456, 0.406],
|
|
||||||
"std": [0.229, 0.224, 0.225],
|
|
||||||
},
|
|
||||||
OBS_STATE: {
|
|
||||||
"min": [0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0],
|
|
||||||
},
|
|
||||||
ACTION: {
|
|
||||||
"min": [0.0, 0.0, 0.0],
|
|
||||||
"max": [1.0, 1.0, 1.0],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Architecture specifics
|
|
||||||
# Device to run the model on (e.g., "cuda", "cpu")
|
|
||||||
device: str = "cpu"
|
|
||||||
# Device to store the model on
|
|
||||||
storage_device: str = "cpu"
|
|
||||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
|
||||||
vision_encoder_name: str | None = None
|
|
||||||
# Whether to freeze the vision encoder during training
|
|
||||||
freeze_vision_encoder: bool = True
|
|
||||||
# Hidden dimension size for the image encoder
|
|
||||||
image_encoder_hidden_dim: int = 32
|
|
||||||
# Whether to use a shared encoder for actor and critic
|
|
||||||
shared_encoder: bool = True
|
|
||||||
# Number of discrete actions, eg for gripper actions
|
|
||||||
num_discrete_actions: int | None = None
|
|
||||||
# Dimension of the image embedding pooling
|
|
||||||
image_embedding_pooling_dim: int = 8
|
|
||||||
|
|
||||||
# Name of the action policy
|
|
||||||
action_policy_name: str = "pi0"
|
|
||||||
action_policy_weights: str | None = "lerobot/pi0_base"
|
|
||||||
|
|
||||||
# Training parameter
|
|
||||||
# Number of steps for online training
|
|
||||||
online_steps: int = 1000000
|
|
||||||
# Capacity of the online replay buffer
|
|
||||||
online_buffer_capacity: int = 100000
|
|
||||||
# Capacity of the offline replay buffer
|
|
||||||
offline_buffer_capacity: int = 100000
|
|
||||||
# Whether to use asynchronous prefetching for the buffers
|
|
||||||
async_prefetch: bool = False
|
|
||||||
# Number of steps before learning starts
|
|
||||||
online_step_before_learning: int = 100
|
|
||||||
# Frequency of policy updates
|
|
||||||
policy_update_freq: int = 1
|
|
||||||
|
|
||||||
# SAC algorithm parameters
|
|
||||||
discount: float = 0.99
|
|
||||||
# Initial temperature value
|
|
||||||
temperature_init: float = 1.0
|
|
||||||
# Number of critics in the ensemble
|
|
||||||
num_critics: int = 2
|
|
||||||
# Number of subsampled critics for training
|
|
||||||
num_subsample_critics: int | None = None
|
|
||||||
# Learning rate for the critic network
|
|
||||||
critic_lr: float = 3e-4
|
|
||||||
# Learning rate for the actor network
|
|
||||||
actor_lr: float = 3e-4
|
|
||||||
# Learning rate for the temperature parameter
|
|
||||||
temperature_lr: float = 3e-4
|
|
||||||
# Weight for the critic target update
|
|
||||||
critic_target_update_weight: float = 0.005
|
|
||||||
# Update-to-data ratio for the UTD algorithm (if you want to enable utd_ratio, you need to set it to >1)
|
|
||||||
utd_ratio: int = 1
|
|
||||||
# Hidden dimension size for the state encoder
|
|
||||||
state_encoder_hidden_dim: int = 256
|
|
||||||
# Dimension of the latent space
|
|
||||||
latent_dim: int = 256
|
|
||||||
# Target entropy for the SAC algorithm
|
|
||||||
target_entropy: float | None = None
|
|
||||||
# Whether to use backup entropy for the SAC algorithm
|
|
||||||
use_backup_entropy: bool = True
|
|
||||||
# Gradient clipping norm for the SAC algorithm
|
|
||||||
grad_clip_norm: float = 40.0
|
|
||||||
|
|
||||||
# Network configuration
|
|
||||||
# Configuration for the critic network architecture
|
|
||||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
|
||||||
# Configuration for the noise critic network architecture
|
|
||||||
noise_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
|
||||||
# Configuration for the noise actor network architecture
|
|
||||||
noise_actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
|
||||||
# Configuration for the noise actor specific parameters
|
|
||||||
noise_actor_kwargs: NoiseActorConfig = field(default_factory=NoiseActorConfig)
|
|
||||||
# Configuration for actor-learner architecture
|
|
||||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
|
||||||
# Configuration for concurrency settings (you can use threads or processes for the actor and learner)
|
|
||||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
|
||||||
|
|
||||||
# Optimizations
|
|
||||||
use_torch_compile: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
super().__post_init__()
|
|
||||||
|
|
||||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
|
||||||
return MultiAdamConfig(
|
|
||||||
weight_decay=0.0,
|
|
||||||
optimizer_groups={
|
|
||||||
"critic_action": {"lr": self.critic_lr},
|
|
||||||
"critic_noise": {"lr": self.critic_lr},
|
|
||||||
"noise_actor": {"lr": self.actor_lr},
|
|
||||||
"temperature": {"lr": self.temperature_lr},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_scheduler_preset(self) -> None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def validate_features(self) -> None:
|
|
||||||
has_image = any(is_image_feature(key) for key in self.input_features)
|
|
||||||
has_state = OBS_STATE in self.input_features
|
|
||||||
|
|
||||||
if not (has_state or has_image):
|
|
||||||
raise ValueError(
|
|
||||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
|
||||||
)
|
|
||||||
|
|
||||||
if ACTION not in self.output_features:
|
|
||||||
raise ValueError("You must provide 'action' in the output features")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def image_features(self) -> list[str]:
|
|
||||||
return [key for key in self.input_features if is_image_feature(key)]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def observation_delta_indices(self) -> list:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def action_delta_indices(self) -> list:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def reward_delta_indices(self) -> None:
|
|
||||||
return None
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,89 +0,0 @@
|
|||||||
# !/usr/bin/env python
|
|
||||||
|
|
||||||
# Copyright 2025 The HuggingFace Inc. team.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
Processor for DSRL policy.
|
|
||||||
|
|
||||||
DSRL uses a similar processing pipeline as SAC since it operates on
|
|
||||||
state-action transitions. The main difference is that internally it
|
|
||||||
also works with noise, but that's handled within the policy itself.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
|
||||||
from lerobot.processor import (
|
|
||||||
AddBatchDimensionProcessorStep,
|
|
||||||
DeviceProcessorStep,
|
|
||||||
NormalizerProcessorStep,
|
|
||||||
PolicyAction,
|
|
||||||
PolicyProcessorPipeline,
|
|
||||||
RenameObservationsProcessorStep,
|
|
||||||
UnnormalizerProcessorStep,
|
|
||||||
)
|
|
||||||
from lerobot.processor.converters import (
|
|
||||||
policy_action_to_transition,
|
|
||||||
transition_to_policy_action,
|
|
||||||
)
|
|
||||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
|
||||||
|
|
||||||
|
|
||||||
def make_dsrl_pre_post_processors(
|
|
||||||
config: DSRLConfig,
|
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
|
||||||
) -> tuple[
|
|
||||||
PolicyProcessorPipeline[dict, dict],
|
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
|
||||||
]:
|
|
||||||
"""Create preprocessor and postprocessor pipelines for DSRL policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: DSRL policy configuration
|
|
||||||
dataset_stats: Optional dataset statistics for normalization
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (preprocessor, postprocessor) pipelines
|
|
||||||
"""
|
|
||||||
input_steps = [
|
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
|
||||||
AddBatchDimensionProcessorStep(),
|
|
||||||
DeviceProcessorStep(device=config.device),
|
|
||||||
NormalizerProcessorStep(
|
|
||||||
features={**config.input_features, **config.output_features},
|
|
||||||
norm_map=config.normalization_mapping,
|
|
||||||
stats=dataset_stats,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
output_steps = [
|
|
||||||
UnnormalizerProcessorStep(
|
|
||||||
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
|
|
||||||
),
|
|
||||||
DeviceProcessorStep(device="cpu"),
|
|
||||||
]
|
|
||||||
return (
|
|
||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
|
||||||
steps=input_steps,
|
|
||||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
|
||||||
),
|
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
|
||||||
steps=output_steps,
|
|
||||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
|
||||||
to_transition=policy_action_to_transition,
|
|
||||||
to_output=transition_to_policy_action,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@@ -30,7 +30,6 @@ from lerobot.envs.configs import EnvConfig
|
|||||||
from lerobot.envs.utils import env_to_policy_features
|
from lerobot.envs.utils import env_to_policy_features
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig
|
|
||||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
@@ -60,7 +59,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "dsrl".
|
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The policy class corresponding to the given name.
|
The policy class corresponding to the given name.
|
||||||
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
|||||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
|
||||||
return SmolVLAPolicy
|
return SmolVLAPolicy
|
||||||
elif name == "dsrl":
|
|
||||||
from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy
|
|
||||||
|
|
||||||
return DSRLPolicy
|
|
||||||
elif name == "groot":
|
elif name == "groot":
|
||||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||||
|
|
||||||
@@ -126,7 +121,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
Args:
|
Args:
|
||||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||||
"reward_classifier", "dsrl".
|
"reward_classifier".
|
||||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -153,8 +148,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return SmolVLAConfig(**kwargs)
|
return SmolVLAConfig(**kwargs)
|
||||||
elif policy_type == "reward_classifier":
|
elif policy_type == "reward_classifier":
|
||||||
return RewardClassifierConfig(**kwargs)
|
return RewardClassifierConfig(**kwargs)
|
||||||
elif policy_type == "dsrl":
|
|
||||||
return DSRLConfig(**kwargs)
|
|
||||||
elif policy_type == "groot":
|
elif policy_type == "groot":
|
||||||
return GrootConfig(**kwargs)
|
return GrootConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
@@ -328,21 +321,6 @@ def make_pre_post_processors(
|
|||||||
config=policy_cfg,
|
config=policy_cfg,
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
dataset_stats=kwargs.get("dataset_stats"),
|
||||||
)
|
)
|
||||||
elif isinstance(policy_cfg, DSRLConfig):
|
|
||||||
from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors
|
|
||||||
|
|
||||||
processors = make_dsrl_pre_post_processors(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(policy_cfg, GrootConfig):
|
|
||||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
|
||||||
|
|
||||||
processors = make_groot_pre_post_processors(
|
|
||||||
config=policy_cfg,
|
|
||||||
dataset_stats=kwargs.get("dataset_stats"),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(policy_cfg, GrootConfig):
|
elif isinstance(policy_cfg, GrootConfig):
|
||||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||||
|
|||||||
@@ -1148,7 +1148,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@@ -1158,7 +1158,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
|
|
||||||
# Sample actions using the model
|
# Sample actions using the model
|
||||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise)
|
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||||
|
|
||||||
# Unpad actions to actual action dimension
|
# Unpad actions to actual action dimension
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -1120,7 +1120,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@@ -1129,7 +1129,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
|
|
||||||
# Sample actions using the model (no separate state needed for PI05)
|
# Sample actions using the model (no separate state needed for PI05)
|
||||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, noise)
|
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||||
|
|
||||||
# Unpad actions to actual action dimension
|
# Unpad actions to actual action dimension
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -0,0 +1,148 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from lerobot.envs.factory import make_env, make_env_config
|
||||||
|
|
||||||
|
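# Select the MuJoCo rendering backend before any environment is created (NOTE(review): this runs after the lerobot import above, so it presumably relies on the simulator being imported lazily inside make_env — confirm)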
|
||||||
|
os.environ["MUJOCO_GL"] = "egl"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_observations_equal(obs1, obs2, path="", atol=1e-8):
    """
    Recursively compare two observations and assert they are equal.

    Dicts are compared key by key (recursing into the values), numpy arrays
    by shape, dtype and value, and any other pair by type and ``==``.

    Args:
        obs1: First observation (dict or numpy array)
        obs2: Second observation (dict or numpy array)
        path: Current path in nested structure (for error messages)
        atol: Absolute tolerance for numpy array comparisons (numpy's default
            relative tolerance of ``np.allclose`` still applies on top of it)

    Raises:
        AssertionError: If the observations differ; the message names the
            path of the first mismatch found.
    """
    if isinstance(obs1, dict) and isinstance(obs2, dict):
        assert obs1.keys() == obs2.keys(), f"Keys differ at {path}: {obs1.keys()} != {obs2.keys()}"
        for key in obs1:
            assert_observations_equal(obs1[key], obs2[key], path=f"{path}.{key}" if path else key, atol=atol)
    elif isinstance(obs1, np.ndarray) and isinstance(obs2, np.ndarray):
        assert obs1.shape == obs2.shape, f"Shape mismatch at {path}: {obs1.shape} != {obs2.shape}"
        assert obs1.dtype == obs2.dtype, f"Dtype mismatch at {path}: {obs1.dtype} != {obs2.dtype}"
        if np.issubdtype(obs1.dtype, np.number) or obs1.dtype == np.bool_:
            # Compute the reported difference in a wide dtype: subtracting in the
            # native dtype wraps around for unsigned integers (e.g. uint8 images)
            # and raises TypeError for booleans, hiding the real assertion failure.
            wide = np.complex128 if np.iscomplexobj(obs1) else np.float64
            assert np.allclose(obs1, obs2, atol=atol), (
                f"Array values differ at {path}: "
                f"max abs diff = {np.abs(obs1.astype(wide) - obs2.astype(wide)).max()}"
            )
        else:
            # np.allclose is undefined for non-numeric dtypes (strings, objects),
            # so fall back to exact element-wise equality.
            assert np.array_equal(obs1, obs2), f"Array values differ at {path}: {obs1} != {obs2}"
    else:
        assert type(obs1) is type(obs2), f"Type mismatch at {path}: {type(obs1)} != {type(obs2)}"
        assert obs1 == obs2, f"Values differ at {path}: {obs1} != {obs2}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_libero_env_creation():
    """Test that the libero environment can be created successfully.

    Builds the "libero_spatial" suite, checks that an environment is
    registered under task id 0 and that a seeded reset returns an
    observation and an info object.
    """
    config = make_env_config("libero", task="libero_spatial")
    envs_dict = make_env(config)

    assert "libero_spatial" in envs_dict
    assert 0 in envs_dict["libero_spatial"]

    env = envs_dict["libero_spatial"][0]
    assert env is not None

    try:
        # Test basic reset
        observation, info = env.reset(seed=42)
        assert observation is not None
        assert info is not None
    finally:
        # Always release the environment, even when an assertion above fails.
        env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_libero_reset_determinism():
    """Test that resetting with the same seed produces identical observations."""
    config = make_env_config("libero", task="libero_spatial")
    envs_dict = make_env(config)
    env = envs_dict["libero_spatial"][0]

    try:
        # Reset multiple times with the same seed; the info values are not compared.
        obs1, _ = env.reset(seed=42)
        obs2, _ = env.reset(seed=42)
        obs3, _ = env.reset(seed=42)

        # All observations should be identical
        assert_observations_equal(obs1, obs2)
        assert_observations_equal(obs1, obs3)
        assert_observations_equal(obs2, obs3)
    finally:
        # Always release the environment, even when a comparison above fails.
        env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_libero_step_determinism():
    """Test that step() is deterministic when resetting with the same seed.

    Two rollouts are run from the same seed with the same sampled action;
    the initial observations, post-step observations, rewards and
    termination/truncation flags must match.
    """
    config = make_env_config("libero", task="libero_spatial")
    envs_dict = make_env(config)
    env = envs_dict["libero_spatial"][0]

    seed = 42

    try:
        # First rollout
        obs1, _ = env.reset(seed=seed)
        # The action is sampled once and reused, so both rollouts apply the same action.
        action = env.action_space.sample()
        obs_after_step1, reward1, terminated1, truncated1, _ = env.step(action)

        # Second rollout with identical seed and action
        obs2, _ = env.reset(seed=seed)
        obs_after_step2, reward2, terminated2, truncated2, _ = env.step(action)

        # Initial observations should be identical
        assert_observations_equal(obs1, obs2)

        # Post-step observations should be identical
        assert_observations_equal(obs_after_step1, obs_after_step2)

        # Rewards and termination flags should be identical
        assert np.allclose(reward1, reward2), f"Rewards differ: {reward1} != {reward2}"
        assert np.array_equal(terminated1, terminated2), (
            f"Terminated flags differ: {terminated1} != {terminated2}"
        )
        assert np.array_equal(truncated1, truncated2), f"Truncated flags differ: {truncated1} != {truncated2}"
    finally:
        # Always release the environment, even when an assertion above fails.
        env.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("task", ["libero_spatial", "libero_object", "libero_goal", "libero_10"])
def test_libero_tasks(task):
    """Test that different libero tasks can be created and used.

    For each parametrized suite name, the environment registered under
    task id 0 is reset with a fixed seed and stepped once with a sampled
    action.
    """
    config = make_env_config("libero", task=task)
    envs_dict = make_env(config)

    assert task in envs_dict
    assert 0 in envs_dict[task]

    env = envs_dict[task][0]

    try:
        observation, info = env.reset(seed=42)

        assert observation is not None
        assert info is not None

        # Take a step
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)

        assert obs is not None
        assert reward is not None
        # np.bool_ is accepted as well: numpy scalar flags are not instances of bool.
        assert isinstance(terminated, (bool, np.bool_, np.ndarray))
        assert isinstance(truncated, (bool, np.bool_, np.ndarray))
    finally:
        # Always release the environment, even when an assertion above fails.
        env.close()
|
||||||
Reference in New Issue
Block a user