diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index b3fdf0626..9f250d72b 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -16,6 +16,7 @@ import logging +import torch from torch import nn from lerobot.configs.policies import PreTrainedConfig @@ -34,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.processor.pipeline import RobotProcessor +from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex def get_policy_class(name: str) -> PreTrainedPolicy: @@ -113,13 +114,13 @@ def make_processor( Each policy type has its own processor with specific preprocessing steps. Args: - policy_type: The type of policy to create a processor for (e.g., "act", "diffusion", etc.) + policy_cfg: The config of the policy to create a processor for (e.g., "act", "diffusion", etc.) pretrained_path: Optional path to load a pretrained processor from. If provided, loads the processor from this path instead of creating a new one. **kwargs: Additional keyword arguments passed to the processor creation. Returns: - RobotProcessor: The configured processor instance. + Tuple of (input_processor, output_processor) for the policy. Raises: NotImplementedError: If the policy type doesn't have a processor implemented. @@ -133,51 +134,67 @@ def make_processor( if policy_cfg.type == "tdmpc": from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor - return make_tdmpc_processor(policy_cfg, **kwargs) + processors = make_tdmpc_processor(policy_cfg, **kwargs) elif policy_cfg.type == "diffusion": from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor - return make_diffusion_processor(policy_cfg, **kwargs) + processors = make_diffusion_processor(policy_cfg, **kwargs) elif policy_cfg.type == "act": from lerobot.policies.act.processor_act import make_act_processor - return make_act_processor(policy_cfg, **kwargs) + processors = make_act_processor(policy_cfg, **kwargs) elif policy_cfg.type == "vqbet": from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor - return make_vqbet_processor(policy_cfg, **kwargs) + processors = make_vqbet_processor(policy_cfg, **kwargs) elif policy_cfg.type == "pi0": from lerobot.policies.pi0.processor_pi0 import make_pi0_processor - return make_pi0_processor(policy_cfg, **kwargs) + processors = make_pi0_processor(policy_cfg, **kwargs) elif policy_cfg.type == "pi0fast": from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor - return make_pi0fast_processor(policy_cfg, **kwargs) + processors = make_pi0fast_processor(policy_cfg, **kwargs) elif policy_cfg.type == "sac": from lerobot.policies.sac.processor_sac import make_sac_processor - return make_sac_processor(policy_cfg, **kwargs) + processors = make_sac_processor(policy_cfg, **kwargs) elif policy_cfg.type == "reward_classifier": from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor - return make_classifier_processor(policy_cfg, **kwargs) + processors = make_classifier_processor(policy_cfg, **kwargs) elif policy_cfg.type == "smolvla": from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor - return make_smolvla_processor(policy_cfg, **kwargs) + processors = make_smolvla_processor(policy_cfg, **kwargs) else: raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") + # Helper hook function to detect NaNs in observation + def nan_detection_hook(step_idx: int, transition: EnvTransition) -> None: + observation = transition[TransitionIndex.OBSERVATION] + if observation is not None: + for key, value in observation.items(): + if isinstance(value, torch.Tensor) and torch.isnan(value).any(): + logging.warning(f"NaN detected in observation key '{key}' after step {step_idx}: {value}") + + # Attach the hook to all returned processors + if isinstance(processors, RobotProcessor): + processors = (processors,) # Wrap single processor in tuple for consistency + for processor in processors: + processor.register_after_step_hook(nan_detection_hook) + + return processors + def make_policy( cfg: PreTrainedConfig,