diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index aff893bfa..c95e99d34 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -39,7 +39,7 @@ from lerobot.policies.factory import ( ) from lerobot.policies.normalize import Normalize, Unnormalize from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor.observation_processor import ObservationProcessor +from lerobot.processor.observation_processor import VanillaObservationProcessor from lerobot.processor.pipeline import RobotProcessor, TransitionIndex from lerobot.utils.random_utils import seeded_context from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats @@ -186,7 +186,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs): observation, _ = env.reset(seed=train_cfg.seed) # apply transform to normalize the observations - obs_processor = RobotProcessor([ObservationProcessor()]) + obs_processor = RobotProcessor([VanillaObservationProcessor()]) transition = (observation, None, None, None, None, None, None) processed_transition = obs_processor(transition) observation = processed_transition[TransitionIndex.OBSERVATION]