mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 07:37:10 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -124,13 +124,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||
return ft
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def robot_state_feature_key(self) -> PolicyFeature | None:
|
||||
for key, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE:
|
||||
return key
|
||||
return None
|
||||
|
||||
@property
|
||||
def env_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
|
||||
@@ -14,7 +14,6 @@ from libero.libero.envs import OffScreenRenderEnv
|
||||
|
||||
from lerobot.constants import (
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGE_2,
|
||||
)
|
||||
|
||||
|
||||
@@ -238,13 +237,13 @@ class LiberoEnv(gym.Env):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
image = self._format_raw_obs(raw_obs)["pixels"][OBS_IMAGE]
|
||||
return image
|
||||
|
||||
def render(self):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
formatted = self._format_raw_obs(raw_obs)
|
||||
# grab the "main" camera
|
||||
return formatted["pixels"]["image"]
|
||||
|
||||
|
||||
def _make_envs_task(self, task_suite, task_id: int = 0):
|
||||
task = task_suite.get_task(task_id)
|
||||
self.task = task.name
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
def preprocess_observation(
|
||||
observations: dict[str, np.ndarray], cfg: dict[str, Any] = None
|
||||
) -> dict[str, Tensor]:
|
||||
@@ -52,7 +53,7 @@ def preprocess_observation(
|
||||
imgs = {"pixels": observations["pixels"]}
|
||||
|
||||
# build rename map env_key -> policy_key
|
||||
rename_map = dict(zip(env_img_keys, policy_img_keys))
|
||||
rename_map = dict(zip(env_img_keys, policy_img_keys, strict=False))
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
target_key = rename_map.get(imgkey, imgkey)
|
||||
@@ -83,6 +84,7 @@ def preprocess_observation(
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
|
||||
Reference in New Issue
Block a user