mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix(test): policies
This commit is contained in:
@@ -39,7 +39,7 @@ from lerobot.policies.factory import (
|
|||||||
)
|
)
|
||||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.processor.observation_processor import ObservationProcessor
|
from lerobot.processor.observation_processor import VanillaObservationProcessor
|
||||||
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
from lerobot.processor.pipeline import RobotProcessor, TransitionIndex
|
||||||
from lerobot.utils.random_utils import seeded_context
|
from lerobot.utils.random_utils import seeded_context
|
||||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||||
@@ -186,7 +186,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
observation, _ = env.reset(seed=train_cfg.seed)
|
observation, _ = env.reset(seed=train_cfg.seed)
|
||||||
|
|
||||||
# apply transform to normalize the observations
|
# apply transform to normalize the observations
|
||||||
obs_processor = RobotProcessor([ObservationProcessor()])
|
obs_processor = RobotProcessor([VanillaObservationProcessor()])
|
||||||
transition = (observation, None, None, None, None, None, None)
|
transition = (observation, None, None, None, None, None, None)
|
||||||
processed_transition = obs_processor(transition)
|
processed_transition = obs_processor(transition)
|
||||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||||
|
|||||||
Reference in New Issue
Block a user