mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
refactor(eval): per-env datasets recording, no double reset
- Extract _infer_shape_from_obs() to reduce nesting in feature conversion
- Move dataset creation into rollout() using its own env.reset() observation, eliminating the extra reset in run_one()
- Replace deepcopy with _shallow_copy_obs() for raw observation stashing
- Support batch_size > 1: each parallel env records to its own dataset (single env skips the env_0/ nesting for simplicity)
- One-time warning for env_features keys missing from observations
- Pass recording_dir + env_features through the call chain instead of a pre-built recording_dataset object
This commit is contained in:
@@ -96,31 +96,31 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
|
||||
def _infer_shape_from_obs(key: str, raw_obs: dict, fallback: tuple) -> tuple:
|
||||
"""Infer the observation shape from a raw env observation, stripping the batch dim."""
|
||||
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
return raw_obs[key].shape[1:]
|
||||
if "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key in (f"{OBS_IMAGES}.{cam_name}", cam_name):
|
||||
return img.shape[1:]
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
return pixels.shape[1:]
|
||||
return fallback
|
||||
|
||||
If raw_obs is provided, visual feature shapes are inferred from the actual observation
|
||||
to avoid mismatches between the env config and the real observation resolution.
|
||||
"""
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None:
|
||||
shape = _infer_shape_from_obs(key, raw_obs, shape)
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
elif raw_obs is not None and "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
|
||||
shape = img.shape[1:] # strip batch dim
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
shape = pixels.shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
@@ -147,6 +147,8 @@ def _build_raw_frame(
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if key.startswith("next."):
|
||||
continue
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
@@ -157,9 +159,8 @@ def _build_raw_frame(
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
val = raw_obs[raw_key][env_idx]
|
||||
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
val = raw_obs[key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
@@ -171,6 +172,19 @@ def _build_raw_frame(
|
||||
return frame
|
||||
|
||||
|
||||
def _shallow_copy_obs(obs: dict) -> dict:
|
||||
"""Copy observation dict, cloning only ndarray/Tensor values to avoid mutation."""
|
||||
out = {}
|
||||
for k, v in obs.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
out[k] = v.copy()
|
||||
elif isinstance(v, dict):
|
||||
out[k] = _shallow_copy_obs(v)
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -181,7 +195,8 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -222,9 +237,31 @@ def rollout(
|
||||
if render_callback is not None:
|
||||
render_callback(env)
|
||||
|
||||
raw_observation = deepcopy(observation) if recording_dataset is not None else None
|
||||
recording_datasets: list[LeRobotDataset] | None = None
|
||||
raw_observation = None
|
||||
task_desc = ""
|
||||
if recording_dataset is not None:
|
||||
if recording_dir is not None and env_features is not None:
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=observation)
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
recording_datasets = []
|
||||
for i in range(env.num_envs):
|
||||
root = str(recording_dir / f"env_{i}") if env.num_envs > 1 else str(recording_dir)
|
||||
recording_datasets.append(
|
||||
LeRobotDataset.create(
|
||||
repo_id="eval_recording",
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=True,
|
||||
)
|
||||
)
|
||||
raw_observation = _shallow_copy_obs(observation)
|
||||
obs_keys = set(observation.keys())
|
||||
if "pixels" in obs_keys and isinstance(observation["pixels"], dict):
|
||||
obs_keys.update(f"{OBS_IMAGES}.{c}" for c in observation["pixels"])
|
||||
missing = [k for k in env_features if k != ACTION and not k.startswith("next.") and k not in obs_keys]
|
||||
if missing:
|
||||
logging.warning("Recording: env_features keys %s not found in env observations.", missing)
|
||||
try:
|
||||
task_desc = list(env.call("task_description"))[0]
|
||||
except (AttributeError, NotImplementedError):
|
||||
@@ -302,7 +339,7 @@ def rollout(
|
||||
else:
|
||||
successes = [False] * env.num_envs
|
||||
|
||||
if recording_dataset is not None and raw_observation is not None:
|
||||
if recording_datasets is not None and raw_observation is not None:
|
||||
prev_done = done.copy()
|
||||
for env_idx in range(env.num_envs):
|
||||
if prev_done[env_idx]:
|
||||
@@ -315,12 +352,12 @@ def rollout(
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
recording_datasets[env_idx].features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
recording_datasets[env_idx].add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
recording_dataset.save_episode()
|
||||
raw_observation = deepcopy(observation)
|
||||
recording_datasets[env_idx].save_episode()
|
||||
raw_observation = _shallow_copy_obs(observation)
|
||||
|
||||
# Keep track of which environments are done so far.
|
||||
# Mark the episode as done if we reach the maximum step limit.
|
||||
@@ -360,6 +397,10 @@ def rollout(
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret[OBS_STR] = stacked_observations
|
||||
|
||||
if recording_datasets is not None:
|
||||
for ds in recording_datasets:
|
||||
ds.finalize()
|
||||
|
||||
if hasattr(policy, "use_original_modules"):
|
||||
policy.use_original_modules()
|
||||
|
||||
@@ -378,7 +419,8 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
recording_dataset: Any | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -467,7 +509,8 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
recording_dataset=recording_dataset,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -732,7 +775,8 @@ def eval_one(
|
||||
videos_dir: Path | None,
|
||||
return_episode_data: bool,
|
||||
start_seed: int | None,
|
||||
recording_dataset: Any | None = None,
|
||||
recording_dir: Path | None = None,
|
||||
env_features: dict | None = None,
|
||||
) -> TaskMetrics:
|
||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||
|
||||
@@ -750,7 +794,8 @@ def eval_one(
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
recording_dir=recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
per_episode = task_result["per_episode"]
|
||||
@@ -790,38 +835,25 @@ def run_one(
|
||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
recording_dataset = None
|
||||
task_recording_dir = None
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
features=features,
|
||||
root=str(task_recording_dir),
|
||||
use_videos=True,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dataset=recording_dataset,
|
||||
)
|
||||
finally:
|
||||
if recording_dataset is not None:
|
||||
recording_dataset.finalize()
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
max_episodes_rendered=max_episodes_rendered,
|
||||
videos_dir=task_videos_dir,
|
||||
return_episode_data=return_episode_data,
|
||||
start_seed=start_seed,
|
||||
recording_dir=task_recording_dir,
|
||||
env_features=env_features,
|
||||
)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
metrics.setdefault("video_paths", [])
|
||||
|
||||
Reference in New Issue
Block a user