mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
refactor(policies): Enhance processor creation and add NaN detection hook
This commit is contained in:
committed by
Steven Palma
parent
fc74001202
commit
670a278cbc
@@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -34,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
|||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.processor.pipeline import RobotProcessor
|
from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
@@ -113,13 +114,13 @@ def make_processor(
|
|||||||
Each policy type has its own processor with specific preprocessing steps.
|
Each policy type has its own processor with specific preprocessing steps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
policy_type: The type of policy to create a processor for (e.g., "act", "diffusion", etc.)
|
policy_cfg: The config of the policy to create a processor for. Its `type` field (e.g., "act", "diffusion", etc.) selects which processor factory is used.
|
||||||
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
pretrained_path: Optional path to load a pretrained processor from. If provided, loads
|
||||||
the processor from this path instead of creating a new one.
|
the processor from this path instead of creating a new one.
|
||||||
**kwargs: Additional keyword arguments passed to the processor creation.
|
**kwargs: Additional keyword arguments passed to the processor creation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
RobotProcessor: The configured processor instance.
|
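Tuple of processors for the policy, e.g. (input_processor, output_processor). If the policy-specific factory returns a single RobotProcessor, it is wrapped in a one-element tuple.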
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: If the policy type doesn't have a processor implemented.
|
NotImplementedError: If the policy type doesn't have a processor implemented.
|
||||||
@@ -133,51 +134,67 @@ def make_processor(
|
|||||||
if policy_cfg.type == "tdmpc":
|
if policy_cfg.type == "tdmpc":
|
||||||
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_processor
|
||||||
|
|
||||||
return make_tdmpc_processor(policy_cfg, **kwargs)
|
processors = make_tdmpc_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "diffusion":
|
elif policy_cfg.type == "diffusion":
|
||||||
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
from lerobot.policies.diffusion.processor_diffusion import make_diffusion_processor
|
||||||
|
|
||||||
return make_diffusion_processor(policy_cfg, **kwargs)
|
processors = make_diffusion_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "act":
|
elif policy_cfg.type == "act":
|
||||||
from lerobot.policies.act.processor_act import make_act_processor
|
from lerobot.policies.act.processor_act import make_act_processor
|
||||||
|
|
||||||
return make_act_processor(policy_cfg, **kwargs)
|
processors = make_act_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "vqbet":
|
elif policy_cfg.type == "vqbet":
|
||||||
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
from lerobot.policies.vqbet.processor_vqbet import make_vqbet_processor
|
||||||
|
|
||||||
return make_vqbet_processor(policy_cfg, **kwargs)
|
processors = make_vqbet_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "pi0":
|
elif policy_cfg.type == "pi0":
|
||||||
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
|
from lerobot.policies.pi0.processor_pi0 import make_pi0_processor
|
||||||
|
|
||||||
return make_pi0_processor(policy_cfg, **kwargs)
|
processors = make_pi0_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "pi0fast":
|
elif policy_cfg.type == "pi0fast":
|
||||||
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
|
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_processor
|
||||||
|
|
||||||
return make_pi0fast_processor(policy_cfg, **kwargs)
|
processors = make_pi0fast_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "sac":
|
elif policy_cfg.type == "sac":
|
||||||
from lerobot.policies.sac.processor_sac import make_sac_processor
|
from lerobot.policies.sac.processor_sac import make_sac_processor
|
||||||
|
|
||||||
return make_sac_processor(policy_cfg, **kwargs)
|
processors = make_sac_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "reward_classifier":
|
elif policy_cfg.type == "reward_classifier":
|
||||||
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor
|
||||||
|
|
||||||
return make_classifier_processor(policy_cfg, **kwargs)
|
processors = make_classifier_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
elif policy_cfg.type == "smolvla":
|
elif policy_cfg.type == "smolvla":
|
||||||
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
|
from lerobot.policies.smolvla.processor_smolvla import make_smolvla_processor
|
||||||
|
|
||||||
return make_smolvla_processor(policy_cfg, **kwargs)
|
processors = make_smolvla_processor(policy_cfg, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||||
|
|
||||||
|
# Helper hook function to detect NaNs in observation
|
||||||
|
def nan_detection_hook(step_idx: int, transition: EnvTransition) -> None:
    """Emit a warning for each tensor-valued observation entry containing NaN.

    Args:
        step_idx: Step index supplied by the caller; only used in the log message.
        transition: Transition whose ``TransitionIndex.OBSERVATION`` slot is inspected.

    Returns:
        None. The transition is only read, never modified.
    """
    obs = transition[TransitionIndex.OBSERVATION]
    if obs is None:
        # No observation attached to this transition: nothing to check.
        return

    for key, value in obs.items():
        # Non-tensor entries are skipped; torch.isnan is only applied to tensors.
        if not isinstance(value, torch.Tensor):
            continue
        if torch.isnan(value).any():
            logging.warning(f"NaN detected in observation key '{key}' after step {step_idx}: {value}")
|
||||||
|
|
||||||
|
# Attach the hook to all returned processors
|
||||||
|
if isinstance(processors, RobotProcessor):
|
||||||
|
processors = (processors,) # Wrap single processor in tuple for consistency
|
||||||
|
for processor in processors:
|
||||||
|
processor.register_after_step_hook(nan_detection_hook)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
|
||||||
def make_policy(
|
def make_policy(
|
||||||
cfg: PreTrainedConfig,
|
cfg: PreTrainedConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user