mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
refactor(policies): Enhance processor creation and add NaN detection hook
This commit is contained in:
committed by
Steven Palma
parent
fc74001202
commit
670a278cbc
@@ -16,6 +16,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -34,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor.pipeline import RobotProcessor
|
||||
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
|
||||
|
||||
|
||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
@@ -113,13 +114,13 @@ def make_processor(
|
||||
Each policy type has its own processor with specific preprocessing steps.
|
||||
|
||||
Args:
|
||||
policy_type: The type of policy to create a processor for (e.g., "act", "diffusion", etc.)
|
||||
policy_cfg: The config of the policy to create a processor for; its `type` field (e.g., "act", "diffusion", etc.) selects which processor is built.
|
||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||
the processor from this path instead of creating a new one.
|
||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||
|
||||
Returns:
|
||||
RobotProcessor: The configured processor instance.
|
||||
Tuple of (input_processor, output_processor) for the policy.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||
@@ -133,51 +134,67 @@ def make_processor(
|
||||
if policy_cfg.type == "tdmpc":
|
||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||
|
||||
return make_tdmpc_processor(policy_cfg, **kwargs)
|
||||
processors = make_tdmpc_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "diffusion":
|
||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||
|
||||
return make_diffusion_processor(policy_cfg, **kwargs)
|
||||
processors = make_diffusion_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "act":
|
||||
from lerobot.policies.act.processor_act import make_act_processor
|
||||
|
||||
return make_act_processor(policy_cfg, **kwargs)
|
||||
processors = make_act_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "vqbet":
|
||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||
|
||||
return make_vqbet_processor(policy_cfg, **kwargs)
|
||||
processors = make_vqbet_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "pi0":
|
||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
|
||||
|
||||
return make_pi0_processor(policy_cfg, **kwargs)
|
||||
processors = make_pi0_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "pi0fast":
|
||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
|
||||
|
||||
return make_pi0fast_processor(policy_cfg, **kwargs)
|
||||
processors = make_pi0fast_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "sac":
|
||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||
|
||||
return make_sac_processor(policy_cfg, **kwargs)
|
||||
processors = make_sac_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "reward_classifier":
|
||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||
|
||||
return make_classifier_processor(policy_cfg, **kwargs)
|
||||
processors = make_classifier_processor(policy_cfg, **kwargs)
|
||||
|
||||
elif policy_cfg.type == "smolvla":
|
||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
|
||||
|
||||
return make_smolvla_processor(policy_cfg, **kwargs)
|
||||
processors = make_smolvla_processor(policy_cfg, **kwargs)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||
|
||||
# Helper hook function to detect NaNs in observation
|
||||
def nan_detection_hook(step_idx: int, transition: EnvTransition) -> None:
    """Log a warning for every observation tensor that contains NaN values.

    Args:
        step_idx: Index of the step, reported in the warning message.
        transition: Transition whose ``TransitionIndex.OBSERVATION`` entry is inspected.
    """
    obs = transition[TransitionIndex.OBSERVATION]
    if obs is None:
        # Nothing to inspect for this transition.
        return

    for key, value in obs.items():
        # Only tensors are checked; any other value type is skipped.
        if not isinstance(value, torch.Tensor):
            continue
        if not torch.isnan(value).any():
            continue
        logging.warning(f"NaN detected in observation key '{key}' after step {step_idx}: {value}")
|
||||
|
||||
# Attach the hook to all returned processors
|
||||
if isinstance(processors, RobotProcessor):
|
||||
processors = (processors,) # Wrap single processor in tuple for consistency
|
||||
for processor in processors:
|
||||
processor.register_after_step_hook(nan_detection_hook)
|
||||
|
||||
return processors
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
|
||||
Reference in New Issue
Block a user