fix(test): policies

This commit is contained in:
Adil Zouitine
2025-07-04 11:14:14 +02:00
parent 2a7a0e6129
commit e2fcd140b0
+2 -2
View File
@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
)
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor.observation_processor import ObservationProcessor
from lerobot.processor.observation_processor import VanillaObservationProcessor
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
from lerobot.utils.random_utils import seeded_context
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
@@ -186,7 +186,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
observation, _ = env.reset(seed=train_cfg.seed)
# apply transform to normalize the observations
obs_processor = RobotProcessor([ObservationProcessor()])
obs_processor = RobotProcessor([VanillaObservationProcessor()])
transition = (observation, None, None, None, None, None, None)
processed_transition = obs_processor(transition)
observation = processed_transition[TransitionIndex.OBSERVATION]