mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
Merge branch 'main' into feat/add_pi
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.constants import OBS_IMAGE, REWARD
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_binary_classifier_with_default_params():
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(1,)),
|
||||
}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
@@ -58,7 +58,7 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
REWARD: torch.randint(low=0, high=2, size=(batch_size,)).float(),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
@@ -87,7 +87,7 @@ def test_multiclass_classifier():
|
||||
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
REWARD: PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)),
|
||||
}
|
||||
config.num_cameras = 1
|
||||
config.num_classes = num_classes
|
||||
@@ -97,7 +97,7 @@ def test_multiclass_classifier():
|
||||
|
||||
input = {
|
||||
OBS_IMAGE: torch.rand((batch_size, 3, 128, 128)),
|
||||
"next.reward": torch.rand((batch_size, num_classes)),
|
||||
REWARD: torch.rand((batch_size, num_classes)),
|
||||
}
|
||||
|
||||
images, labels = classifier.extract_images_and_labels(input)
|
||||
|
||||
@@ -69,7 +69,6 @@ def test_sac_config_default_initialization():
|
||||
|
||||
# Training parameters
|
||||
assert config.online_steps == 1000000
|
||||
assert config.online_env_seed == 10000
|
||||
assert config.online_buffer_capacity == 100000
|
||||
assert config.offline_buffer_capacity == 100000
|
||||
assert config.async_prefetch is False
|
||||
|
||||
Reference in New Issue
Block a user