From e3ce2eb74349ece411b0829283eb75c68eb820dc Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 13 Oct 2025 16:12:39 +0200 Subject: [PATCH] update factory with dsrl --- src/lerobot/policies/dsrl/modeling_dsrl.py | 2 +- src/lerobot/policies/factory.py | 26 ++++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/dsrl/modeling_dsrl.py b/src/lerobot/policies/dsrl/modeling_dsrl.py index 0faf3ee73..c1c431947 100644 --- a/src/lerobot/policies/dsrl/modeling_dsrl.py +++ b/src/lerobot/policies/dsrl/modeling_dsrl.py @@ -249,7 +249,7 @@ class DSRLPolicy(PreTrainedPolicy): raise ValueError(f"Unknown model type: {model}") def update_target_networks(self): - """Update target networks with exponential moving average""" + """Update target networks of the action critic with exponential moving average""" for target_param, param in zip( self.action_critic_target.parameters(), self.action_critic_ensemble.parameters(), diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index eb6266757..68cc00df5 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -30,6 +30,7 @@ from lerobot.envs.configs import EnvConfig from lerobot.envs.utils import env_to_policy_features from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.dsrl.configuration_dsrl import DSRLConfig from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi05.configuration_pi05 import PI05Config @@ -59,7 +60,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: Args: name: The name of the policy. Supported names are "tdmpc", "diffusion", "act", - "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla". + "vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "dsrl". Returns: The policy class corresponding to the given name. @@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy return SmolVLAPolicy + elif name == "dsrl": + from lerobot.policies.dsrl.modeling_dsrl import DSRLPolicy + + return DSRLPolicy elif name == "groot": from lerobot.policies.groot.modeling_groot import GrootPolicy @@ -121,7 +126,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: Args: policy_type: The type of the policy. Supported types include "tdmpc", "diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla", - "reward_classifier". + "reward_classifier", "dsrl". **kwargs: Keyword arguments to be passed to the configuration class constructor. Returns: @@ -148,6 +153,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return SmolVLAConfig(**kwargs) elif policy_type == "reward_classifier": return RewardClassifierConfig(**kwargs) + elif policy_type == "dsrl": + return DSRLConfig(**kwargs) elif policy_type == "groot": return GrootConfig(**kwargs) else: @@ -321,6 +328,21 @@ def make_pre_post_processors( config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"), ) + elif isinstance(policy_cfg, DSRLConfig): + from lerobot.policies.dsrl.processor_dsrl import make_dsrl_pre_post_processors + + processors = make_dsrl_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) + + elif isinstance(policy_cfg, GrootConfig): + from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors + + processors = make_groot_pre_post_processors( + config=policy_cfg, + dataset_stats=kwargs.get("dataset_stats"), + ) elif isinstance(policy_cfg, GrootConfig): from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors