mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
refactor(factory): Remove unused imports and NaN detection hook from processor creation
This commit is contained in:
committed by
Steven Palma
parent
8b4a5368b3
commit
21baa8fa02
@@ -16,7 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -35,7 +34,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
|
|||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex
|
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor
|
||||||
|
|
||||||
|
|
||||||
def get_policy_class(name: str) -> PreTrainedPolicy:
|
def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||||
@@ -181,20 +180,6 @@ def make_processor(
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
|
||||||
|
|
||||||
# Helper hook function to detect NaNs in observation
|
|
||||||
def nan_detection_hook(step_idx: int, transition: EnvTransition) -> None:
|
|
||||||
observation = transition[TransitionIndex.OBSERVATION]
|
|
||||||
if observation is not None:
|
|
||||||
for key, value in observation.items():
|
|
||||||
if isinstance(value, torch.Tensor) and torch.isnan(value).any():
|
|
||||||
logging.warning(f"NaN detected in observation key '{key}' after step {step_idx}: {value}")
|
|
||||||
|
|
||||||
# Attach the hook to all returned processors
|
|
||||||
if isinstance(processors, RobotProcessor):
|
|
||||||
processors = (processors,) # Wrap single processor in tuple for consistency
|
|
||||||
for processor in processors:
|
|
||||||
processor.register_after_step_hook(nan_detection_hook)
|
|
||||||
|
|
||||||
return processors
|
return processors
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user