refactor(eval): remove shape inference and shallow copy helpers

This commit is contained in:
Khalil Meftah
2026-06-16 22:13:23 +02:00
parent afeeeb8982
commit 4f5e6596be
+4 -40
View File
@@ -96,28 +96,11 @@ from lerobot.utils.utils import (
)
def _infer_shape_from_obs(key: str, raw_obs: dict, fallback: tuple) -> tuple:
"""Infer the observation shape from a raw env observation, stripping the batch dim."""
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
return raw_obs[key].shape[1:]
if "pixels" in raw_obs:
pixels = raw_obs["pixels"]
if isinstance(pixels, dict):
for cam_name, img in pixels.items():
if key in (f"{OBS_IMAGES}.{cam_name}", cam_name):
return img.shape[1:]
elif key in ("pixels", OBS_IMAGE):
return pixels.shape[1:]
return fallback
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
def _env_features_to_dataset_features(env_features: dict) -> dict:
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
features = {}
for key, ft in env_features.items():
shape = tuple(ft.shape)
if raw_obs is not None:
shape = _infer_shape_from_obs(key, raw_obs, shape)
if ft.type is FeatureType.VISUAL:
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
@@ -172,19 +155,6 @@ def _build_raw_frame(
return frame
def _shallow_copy_obs(obs: dict) -> dict:
"""Copy observation dict, cloning only ndarray/Tensor values to avoid mutation."""
out = {}
for k, v in obs.items():
if isinstance(v, np.ndarray):
out[k] = v.copy()
elif isinstance(v, dict):
out[k] = _shallow_copy_obs(v)
else:
out[k] = v
return out
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
@@ -241,7 +211,7 @@ def rollout(
raw_observation = None
task_desc = ""
if recording_dir is not None and env_features is not None:
features = _env_features_to_dataset_features(env_features, raw_obs=observation)
features = _env_features_to_dataset_features(env_features)
fps = env.unwrapped.metadata.get("render_fps", 30)
recording_datasets = []
for i in range(env.num_envs):
@@ -255,13 +225,7 @@ def rollout(
use_videos=True,
)
)
raw_observation = _shallow_copy_obs(observation)
obs_keys = set(observation.keys())
if "pixels" in obs_keys and isinstance(observation["pixels"], dict):
obs_keys.update(f"{OBS_IMAGES}.{c}" for c in observation["pixels"])
missing = [k for k in env_features if k != ACTION and not k.startswith("next.") and k not in obs_keys]
if missing:
logging.warning("Recording: env_features keys %s not found in env observations.", missing)
raw_observation = deepcopy(observation)
try:
task_desc = list(env.call("task_description"))[0]
except (AttributeError, NotImplementedError):
@@ -357,7 +321,7 @@ def rollout(
recording_datasets[env_idx].add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_datasets[env_idx].save_episode()
raw_observation = _shallow_copy_obs(observation)
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.