mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 17:27:03 +00:00
fix(eval): align raw frame keys with dataset schema and fix numpy types
This commit is contained in:
@@ -85,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -122,26 +122,34 @@ def _build_raw_frame(
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index."""
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
if "pixels" in raw_obs:
|
||||
if isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
frame[f"{OBS_IMAGES}.{cam_name}"] = img[env_idx]
|
||||
else:
|
||||
frame[OBS_IMAGE] = raw_obs["pixels"][env_idx]
|
||||
if "agent_pos" in raw_obs:
|
||||
frame[OBS_STATE] = raw_obs["agent_pos"][env_idx]
|
||||
for key, val in raw_obs.items():
|
||||
if key in ("pixels", "agent_pos"):
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if isinstance(val, np.ndarray):
|
||||
frame[f"{OBS_STR}.{key}"] = val[env_idx]
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
frame[key] = raw_obs[raw_key][env_idx]
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.float32(reward)
|
||||
frame["next.success"] = success
|
||||
frame["next.done"] = done
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
@@ -290,6 +298,7 @@ def rollout(
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
|
||||
Reference in New Issue
Block a user