refactor(policies): Enhance processor creation and add NaN detection hook

This commit is contained in:
AdilZouitine
2025-07-10 18:50:25 +02:00
committed by Steven Palma
parent fc74001202
commit 670a278cbc
+29 -12
View File
@@ -16,6 +16,7 @@
import logging
import torch
from torch import nn
from lerobot.configs.policies import PreTrainedConfig
@@ -34,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.processor.pipeline import RobotProcessor
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
def get_policy_class(name: str) -> PreTrainedPolicy:
@@ -113,13 +114,13 @@ def make_processor(
Each policy type has its own processor with specific preprocessing steps.
Args:
policy_type: The type of policy to create a processor for (e.g., "act", "diffusion", etc.)
policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.)
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
the processor from this path instead of creating a new one.
**kwargs: Additional keyword arguments passed to the processor creation.
Returns:
RobotProcessor: The configured processor instance.
Tuple of (input_processor, output_processor) for the policy.
Raises:
NotImplementedError: If the policy type doesn't have a processor implemented.
@@ -133,51 +134,67 @@ def make_processor(
if policy_cfg.type == "tdmpc":
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
return make_tdmpc_processor(policy_cfg, **kwargs)
processors = make_tdmpc_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "diffusion":
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
return make_diffusion_processor(policy_cfg, **kwargs)
processors = make_diffusion_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "act":
from lerobot.policies.act.processor_act import make_act_processor
return make_act_processor(policy_cfg, **kwargs)
processors = make_act_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "vqbet":
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
return make_vqbet_processor(policy_cfg, **kwargs)
processors = make_vqbet_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "pi0":
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
return make_pi0_processor(policy_cfg, **kwargs)
processors = make_pi0_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "pi0fast":
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
return make_pi0fast_processor(policy_cfg, **kwargs)
processors = make_pi0fast_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "sac":
from lerobot.policies.sac.processor_sac import make_sac_processor
return make_sac_processor(policy_cfg, **kwargs)
processors = make_sac_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "reward_classifier":
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
return make_classifier_processor(policy_cfg, **kwargs)
processors = make_classifier_processor(policy_cfg, **kwargs)
elif policy_cfg.type == "smolvla":
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
return make_smolvla_processor(policy_cfg, **kwargs)
processors = make_smolvla_processor(policy_cfg, **kwargs)
else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
# Helper hook function to detect NaNs in observation
def nan_detection_hook(step_idx: int, transition: EnvTransition) -> None:
observation = transition[TransitionIndex.OBSERVATION]
if observation is not None:
for key, value in observation.items():
if isinstance(value, torch.Tensor) and torch.isnan(value).any():
logging.warning(f"NaN detected in observation key '{key}' after step {step_idx}: {value}")
# Attach the hook to all returned processors
if isinstance(processors, RobotProcessor):
processors = (processors,) # Wrap single processor in tuple for consistency
for processor in processors:
processor.register_after_step_hook(nan_detection_hook)
return processors
def make_policy(
cfg: PreTrainedConfig,