From e4c1a8472d0ea18e9c13e219e04758fe7b30ad06 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Sat, 18 Apr 2026 15:15:59 +0200 Subject: [PATCH] fix(config): update vision encoder model name to lerobot/resnet10 --- src/lerobot/policies/sac/configuration_sac.py | 2 +- .../sac/reward_model/configuration_classifier.py | 2 +- tests/policies/hilserl/test_modeling_classifier.py | 13 ------------- tests/policies/test_sac_policy.py | 2 +- 4 files changed, 3 insertions(+), 16 deletions(-) diff --git a/src/lerobot/policies/sac/configuration_sac.py b/src/lerobot/policies/sac/configuration_sac.py index 8eddc0b46..a6c7d7f21 100644 --- a/src/lerobot/policies/sac/configuration_sac.py +++ b/src/lerobot/policies/sac/configuration_sac.py @@ -122,7 +122,7 @@ class SACConfig(PreTrainedConfig): device: str = "cpu" # Device to store the model on storage_device: str = "cpu" - # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) + # Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10) vision_encoder_name: str | None = None # Whether to freeze the vision encoder during training freeze_vision_encoder: bool = True diff --git a/src/lerobot/policies/sac/reward_model/configuration_classifier.py b/src/lerobot/policies/sac/reward_model/configuration_classifier.py index e23f4da87..d00b3bce8 100644 --- a/src/lerobot/policies/sac/reward_model/configuration_classifier.py +++ b/src/lerobot/policies/sac/reward_model/configuration_classifier.py @@ -31,7 +31,7 @@ class RewardClassifierConfig(PreTrainedConfig): latent_dim: int = 256 image_embedding_pooling_dim: int = 8 dropout_rate: float = 0.1 - model_name: str = "helper2424/resnet10" + model_name: str = "lerobot/resnet10" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py index 6d262c01b..efdfffc87 100644 --- a/tests/policies/hilserl/test_modeling_classifier.py +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature @@ -38,9 +37,6 @@ def test_classifier_output(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_binary_classifier_with_default_params(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -82,9 +78,6 @@ def test_binary_classifier_with_default_params(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_multiclass_classifier(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -124,9 +117,6 @@ def test_multiclass_classifier(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_default_device(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier @@ -139,9 +129,6 @@ def test_default_device(): @skip_if_package_missing("transformers") -@pytest.mark.skip( - reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers" -) def test_explicit_device_setup(): from lerobot.policies.sac.reward_model.modeling_classifier import Classifier diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 43544acf1..e9f0087c6 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -304,7 +304,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_ @pytest.mark.parametrize( "batch_size,state_dim,action_dim,vision_encoder_name", - [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], + [(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], ) @pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") def test_sac_policy_with_pretrained_encoder(