mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
Refactor observation preprocessing to use a modular pipeline system
- Introduced `RobotPipeline` and `ObservationProcessor` for handling observation transformations. - Updated `preprocess_observation` to maintain backward compatibility while leveraging the new pipeline. - Added tests for the new processing components and ensured they match the original functionality. - Removed hardcoded logic in favor of a more flexible, composable architecture.
This commit is contained in:
@@ -30,7 +30,8 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.processor.pipeline import RobotPipeline, TransitionIndex
|
||||
from lerobot.processor.observation_processor import ObservationProcessor
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.policies.factory import (
|
||||
@@ -185,7 +186,10 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
observation, _ = env.reset(seed=train_cfg.seed)
|
||||
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation)
|
||||
obs_pipeline = RobotPipeline([ObservationProcessor()])
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
processed_transition = obs_pipeline(transition)
|
||||
observation = processed_transition[TransitionIndex.OBSERVATION]
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
|
||||
Reference in New Issue
Block a user