Merge branch 'main' into feat/add_pi

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Steven Palma
2025-09-29 11:51:06 +02:00
committed by GitHub
55 changed files with 153 additions and 440 deletions
@@ -19,7 +19,7 @@ import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
-from lerobot.utils.constants import OBS_IMAGE
+from lerobot.utils.constants import OBS_IMAGE, REWARD
from tests.utils import require_package
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
-"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
+REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
}
config.normalization_mapping = {
"VISUAL": NormalizationMode.IDENTITY,
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
-"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
+REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
}
images, labels = classifier.extract_images_and_labels(input)
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
-"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
+REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
}
config.num_cameras = 1
config.num_classes = num_classes
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
input = {
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
-"next.reward": torch.rand((batch_size, num_classes)),
+REWARD: torch.rand((batch_size, num_classes)),
}
images, labels = classifier.extract_images_and_labels(input)
-1
View File
@@ -69,7 +69,6 @@ def test_sac_config_default_initialization():
# Training parameters
assert config.online_steps == 1000000
assert config.online_env_seed == 10000
assert config.online_buffer_capacity == 100000
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False