mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import Any
|
|||||||
import draccus
|
import draccus
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGE_2, OBS_IMAGES, OBS_STATE
|
from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||||
from lerobot.robots import RobotConfig
|
from lerobot.robots import RobotConfig
|
||||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from lerobot.configs.types import FeatureType, PolicyFeature
|
|||||||
from lerobot.envs.configs import EnvConfig
|
from lerobot.envs.configs import EnvConfig
|
||||||
from lerobot.utils.utils import get_channel_first_image_shape
|
from lerobot.utils.utils import get_channel_first_image_shape
|
||||||
|
|
||||||
|
|
||||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||||
"""Convert environment observation to LeRobot format observation.
|
"""Convert environment observation to LeRobot format observation.
|
||||||
@@ -79,6 +80,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
|
|||||||
@@ -460,6 +460,7 @@ def _compile_episode_data(
|
|||||||
|
|
||||||
return data_dict
|
return data_dict
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
def eval_main(cfg: EvalPipelineConfig):
|
def eval_main(cfg: EvalPipelineConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|||||||
Reference in New Issue
Block a user