Change SAC policy implementation with configuration and modeling classes

This commit is contained in:
Adil Zouitine
2025-01-17 09:39:04 +01:00
committed by Michel Aractingi
parent 8105efb338
commit 7d2970fdfe
4 changed files with 55 additions and 718 deletions
+4 -5
View File
@@ -71,7 +71,6 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
from lerobot.common.policies.sac.modeling_sac import SACPolicy
return SACPolicy, SACConfig
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -91,10 +90,10 @@ def make_policy(
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
)
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
# raise ValueError(
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
# )
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)