Merge branch 'main' into feat/add_pi_conflicts_main

This commit is contained in:
Steven Palma
2025-09-26 13:38:31 +02:00
68 changed files with 963 additions and 865 deletions
@@ -19,6 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
from lerobot.utils.constants import OBS_IMAGE
from tests.utils import require_package
@@ -41,7 +42,7 @@ def test_binary_classifier_with_default_params():
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
@@ -56,7 +57,7 @@ def test_binary_classifier_with_default_params():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
@@ -83,7 +84,7 @@ def test_multiclass_classifier():
num_classes = 5
config = RewardClassifierConfig()
config.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
@@ -95,7 +96,7 @@ def test_multiclass_classifier():
batch_size = 10
input = {
"observation.image": torch.rand((batch_size, 3, 128, 128)),
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
"next.reward": torch.rand((batch_size, num_classes)),
}
+10 -10
View File
@@ -41,7 +41,7 @@ from lerobot.policies.factory import (
make_pre_post_processors,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION, OBS_STATE
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -52,19 +52,19 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
# Create only one camera input which is squared to fit all current policy constraints
# e.g. vqbet and tdmpc works with one camera only, and tdmpc requires it to be squared
camera_features = {
"observation.images.laptop": {
f"{OBS_IMAGES}.laptop": {
"shape": (84, 84, 3),
"names": ["height", "width", "channels"],
"info": None,
},
}
motor_features = {
"action": {
ACTION: {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"observation.state": {
OBS_STATE: {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
@@ -281,13 +281,13 @@ def test_multikey_construction(multikey: bool):
preventing erroneous creation of the policy object.
"""
input_features = {
"observation.state": PolicyFeature(
OBS_STATE: PolicyFeature(
type=FeatureType.STATE,
shape=(10,),
),
}
output_features = {
"action": PolicyFeature(
ACTION: PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
),
@@ -297,14 +297,14 @@ def test_multikey_construction(multikey: bool):
"""Simulates the complete state/action is constructed from more granular multiple
keys, of the same type as the overall state/action"""
input_features = {}
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
input_features[f"{OBS_STATE}.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features[f"{OBS_STATE}.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
input_features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
output_features = {}
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
output_features["action"] = PolicyFeature(
output_features[ACTION] = PolicyFeature(
type=FeatureType.ACTION,
shape=(5,),
)
+11 -10
View File
@@ -25,6 +25,7 @@ from lerobot.policies.sac.configuration_sac import (
PolicyConfig,
SACConfig,
)
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
def test_sac_config_default_initialization():
@@ -37,15 +38,15 @@ def test_sac_config_default_initialization():
"ACTION": NormalizationMode.MIN_MAX,
}
assert config.dataset_stats == {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -90,15 +91,15 @@ def test_sac_config_default_initialization():
# Dataset stats defaults
expected_dataset_stats = {
"observation.image": {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
"observation.state": {
OBS_STATE: {
"min": [0.0, 0.0],
"max": [1.0, 1.0],
},
"action": {
ACTION: {
"min": [0.0, 0.0, 0.0],
"max": [1.0, 1.0, 1.0],
},
@@ -191,8 +192,8 @@ def test_sac_config_custom_initialization():
def test_validate_features():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
config.validate_features()
@@ -200,7 +201,7 @@ def test_validate_features():
def test_validate_features_missing_observation():
config = SACConfig(
input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(
ValueError, match="You must provide either 'observation.state' or an image observation"
@@ -210,7 +211,7 @@ def test_validate_features_missing_observation():
def test_validate_features_missing_action():
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,))},
output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))},
)
with pytest.raises(ValueError, match="You must provide 'action' in the output features"):
+15 -14
View File
@@ -23,6 +23,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
try:
@@ -85,14 +86,14 @@ def test_sac_policy_with_default_args():
def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor:
return {
"observation.image": torch.randn(batch_size, 3, 84, 84),
"observation.state": torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
}
@@ -104,7 +105,7 @@ def create_default_train_batch(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_state(batch_size, state_dim),
"next_state": create_dummy_state(batch_size, state_dim),
@@ -116,7 +117,7 @@ def create_train_batch_with_visual_input(
batch_size: int = 8, state_dim: int = 10, action_dim: int = 10
) -> dict[str, Tensor]:
return {
"action": create_dummy_action(batch_size, action_dim),
ACTION: create_dummy_action(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": create_dummy_with_visual_input(batch_size, state_dim),
"next_state": create_dummy_with_visual_input(batch_size, state_dim),
@@ -126,14 +127,14 @@ def create_train_batch_with_visual_input(
def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
OBS_STATE: torch.randn(batch_size, state_dim),
}
def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]:
return {
"observation.state": torch.randn(batch_size, state_dim),
"observation.image": torch.randn(batch_size, 3, 84, 84),
OBS_STATE: torch.randn(batch_size, state_dim),
OBS_IMAGE: torch.randn(batch_size, 3, 84, 84),
}
@@ -180,14 +181,14 @@ def create_default_config(
action_dim += 1
config = SACConfig(
input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))},
dataset_stats={
"observation.state": {
OBS_STATE: {
"min": [0.0] * state_dim,
"max": [1.0] * state_dim,
},
"action": {
ACTION: {
"min": [0.0] * continuous_action_dim,
"max": [1.0] * continuous_action_dim,
},
@@ -205,8 +206,8 @@ def create_config_with_visual_input(
continuous_action_dim=continuous_action_dim,
has_discrete_action=has_discrete_action,
)
config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats["observation.image"] = {
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1),
"std": torch.randn(3, 1, 1),
}