mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| acd31c7de2 |
@@ -72,7 +72,7 @@ from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs import FeatureType, parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs import (
|
||||
@@ -85,7 +85,7 @@ from lerobot.envs import (
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.types import PolicyAction
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.io_utils import write_video
|
||||
@@ -96,18 +96,32 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create()."""
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
|
||||
|
||||
If raw_obs is provided, visual feature shapes are inferred from the actual observation
|
||||
to avoid mismatches between the env config and the real observation resolution.
|
||||
"""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
if ft.type.value == "visual":
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"shape": tuple(ft.shape),
|
||||
"names": ["channel", "height", "width"],
|
||||
}
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
elif raw_obs is not None and "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
|
||||
shape = img.shape[1:] # strip batch dim
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
shape = pixels.shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
features[key] = {"dtype": "float32", "shape": tuple(ft.shape), "names": None}
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
@@ -122,26 +136,37 @@ def _build_raw_frame(
|
||||
success: bool,
|
||||
done: bool,
|
||||
task: str,
|
||||
env_features: dict,
|
||||
) -> dict:
|
||||
"""Build a dataset frame from raw env observations for one env index."""
|
||||
"""Build a dataset frame from raw env observations for one env index.
|
||||
|
||||
Keys in the frame match the keys in env_features so they align with the
|
||||
dataset schema created by _env_features_to_dataset_features().
|
||||
"""
|
||||
frame: dict[str, Any] = {}
|
||||
if "pixels" in raw_obs:
|
||||
if isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
frame[f"{OBS_IMAGES}.{cam_name}"] = img[env_idx]
|
||||
else:
|
||||
frame[OBS_IMAGE] = raw_obs["pixels"][env_idx]
|
||||
if "agent_pos" in raw_obs:
|
||||
frame[OBS_STATE] = raw_obs["agent_pos"][env_idx]
|
||||
for key, val in raw_obs.items():
|
||||
if key in ("pixels", "agent_pos"):
|
||||
for key in env_features:
|
||||
if key == ACTION:
|
||||
continue
|
||||
if isinstance(val, np.ndarray):
|
||||
frame[f"{OBS_STR}.{key}"] = val[env_idx]
|
||||
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||
for cam_name, img in raw_obs["pixels"].items():
|
||||
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||
if candidate == key:
|
||||
frame[key] = img[env_idx]
|
||||
if key in frame:
|
||||
continue
|
||||
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||
frame[key] = raw_obs["pixels"][env_idx]
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
val = raw_obs[raw_key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.float32(reward)
|
||||
frame["next.success"] = success
|
||||
frame["next.done"] = done
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||
frame["task"] = task
|
||||
return frame
|
||||
|
||||
@@ -290,6 +315,7 @@ def rollout(
|
||||
successes[env_idx],
|
||||
bool(terminated[env_idx] | truncated[env_idx]),
|
||||
task_desc,
|
||||
recording_dataset.features,
|
||||
)
|
||||
recording_dataset.add_frame(frame)
|
||||
if terminated[env_idx] or truncated[env_idx]:
|
||||
@@ -768,7 +794,8 @@ def run_one(
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
|
||||
Reference in New Issue
Block a user