mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
fix(config): update vision encoder model name to lerobot/resnet10
This commit is contained in:
@@ -122,7 +122,7 @@ class SACConfig(PreTrainedConfig):
|
|||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
# Device to store the model on
|
# Device to store the model on
|
||||||
storage_device: str = "cpu"
|
storage_device: str = "cpu"
|
||||||
# Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10)
|
# Name of the vision encoder model (Set to "lerobot/resnet10" for hil serl resnet10)
|
||||||
vision_encoder_name: str | None = None
|
vision_encoder_name: str | None = None
|
||||||
# Whether to freeze the vision encoder during training
|
# Whether to freeze the vision encoder during training
|
||||||
freeze_vision_encoder: bool = True
|
freeze_vision_encoder: bool = True
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class RewardClassifierConfig(PreTrainedConfig):
|
|||||||
latent_dim: int = 256
|
latent_dim: int = 256
|
||||||
image_embedding_pooling_dim: int = 8
|
image_embedding_pooling_dim: int = 8
|
||||||
dropout_rate: float = 0.1
|
dropout_rate: float = 0.1
|
||||||
model_name: str = "helper2424/resnet10"
|
model_name: str = "lerobot/resnet10"
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
model_type: str = "cnn" # "transformer" or "cnn"
|
model_type: str = "cnn" # "transformer" or "cnn"
|
||||||
num_cameras: int = 2
|
num_cameras: int = 2
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||||
@@ -38,9 +37,6 @@ def test_classifier_output():
|
|||||||
|
|
||||||
|
|
||||||
@skip_if_package_missing("transformers")
|
@skip_if_package_missing("transformers")
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
|
||||||
)
|
|
||||||
def test_binary_classifier_with_default_params():
|
def test_binary_classifier_with_default_params():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -82,9 +78,6 @@ def test_binary_classifier_with_default_params():
|
|||||||
|
|
||||||
|
|
||||||
@skip_if_package_missing("transformers")
|
@skip_if_package_missing("transformers")
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
|
||||||
)
|
|
||||||
def test_multiclass_classifier():
|
def test_multiclass_classifier():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -124,9 +117,6 @@ def test_multiclass_classifier():
|
|||||||
|
|
||||||
|
|
||||||
@skip_if_package_missing("transformers")
|
@skip_if_package_missing("transformers")
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
|
||||||
)
|
|
||||||
def test_default_device():
|
def test_default_device():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
@@ -139,9 +129,6 @@ def test_default_device():
|
|||||||
|
|
||||||
|
|
||||||
@skip_if_package_missing("transformers")
|
@skip_if_package_missing("transformers")
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="helper2424/resnet10 needs to be updated to work with the latest version of transformers"
|
|
||||||
)
|
|
||||||
def test_explicit_device_setup():
|
def test_explicit_device_setup():
|
||||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||||
|
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_
|
|||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
[(1, 6, 6, "lerobot/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||||
)
|
)
|
||||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||||
def test_sac_policy_with_pretrained_encoder(
|
def test_sac_policy_with_pretrained_encoder(
|
||||||
|
|||||||
Reference in New Issue
Block a user