mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
refactor(eval): remove shape inference and shallow copy helpers
This commit is contained in:
@@ -96,28 +96,11 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _infer_shape_from_obs(key: str, raw_obs: dict, fallback: tuple) -> tuple:
|
||||
"""Infer the observation shape from a raw env observation, stripping the batch dim."""
|
||||
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
return raw_obs[key].shape[1:]
|
||||
if "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key in (f"{OBS_IMAGES}.{cam_name}", cam_name):
|
||||
return img.shape[1:]
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
return pixels.shape[1:]
|
||||
return fallback
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None:
|
||||
shape = _infer_shape_from_obs(key, raw_obs, shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
@@ -172,19 +155,6 @@ def _build_raw_frame(
|
||||
return frame
|
||||
|
||||
|
||||
def _shallow_copy_obs(obs: dict) -> dict:
|
||||
"""Copy observation dict, cloning only ndarray/Tensor values to avoid mutation."""
|
||||
out = {}
|
||||
for k, v in obs.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
out[k] = v.copy()
|
||||
elif isinstance(v, dict):
|
||||
out[k] = _shallow_copy_obs(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -241,7 +211,7 @@ def rollout(
|
||||
raw_observation = None
|
||||
task_desc = ""
|
||||
if recording_dir is not None and env_features is not None:
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=observation)
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
recording_datasets = []
|
||||
for i in range(env.num_envs):
|
||||
@@ -255,13 +225,7 @@ def rollout(
|
||||
use_videos=True,
|
||||
)
|
||||
)
|
||||
raw_observation = _shallow_copy_obs(observation)
|
||||
obs_keys = set(observation.keys())
|
||||
if "pixels" in obs_keys and isinstance(observation["pixels"], dict):
|
||||
obs_keys.update(f"{OBS_IMAGES}.{c}" for c in observation["pixels"])
|
||||
missing = [k for k in env_features if k != ACTION and not k.startswith("next.") and k not in obs_keys]
|
||||
if missing:
|
||||
logging.warning("Recording: env_features keys %s not found in env observations.", missing)
|
||||
raw_observation = deepcopy(observation)
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
@@ -357,7 +321,7 @@ def rollout(
|
||||
recording_datasets[env_idx].add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_datasets[env_idx].save_episode()
|
||||
raw_observation = _shallow_copy_obs(observation)
|
||||
raw_observation = deepcopy(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
|
||||
Reference in New Issue
Block a user