mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-11 12:09:41 +00:00
ad1381915c
* update agibot2lerobot * update libero2lerobot * update robomind2lerobot * fix robomind2lerobot
87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import numpy as np
|
|
from lerobot.datasets.compute_stats import (
|
|
DEFAULT_QUANTILES,
|
|
auto_downsample_height_width,
|
|
get_feature_stats,
|
|
sample_indices,
|
|
)
|
|
from lerobot.datasets.io_utils import load_image_as_numpy
|
|
|
|
|
|
def generate_features_from_config(AgiBotWorld_CONFIG):
    """Flatten an AgiBotWorld config into a LeRobot feature mapping.

    Image entries become ``observation.images.<key>``, state entries become
    ``observation.states.<key>``, and action entries become ``actions.<key>``.
    The value (feature spec) is passed through unchanged.
    """
    section_prefixes = {
        "images": "observation.images",
        "states": "observation.states",
        "actions": "actions",
    }
    # Insertion order matches the original images -> states -> actions walk.
    return {
        f"{prefix}.{name}": spec
        for section, prefix in section_prefixes.items()
        for name, spec in AgiBotWorld_CONFIG[section].items()
    }
|
|
|
|
|
|
def sample_images(input):
    """Sample a subset of frames and stack them into one uint8 array.

    Args:
        input: Either a list of image file paths, or a numpy array of
            single-channel frames (presumably depth) with shape ``[T, H, W]``
            — TODO confirm expected array layout against callers.

    Returns:
        ``np.ndarray`` of shape ``[num_sampled, C, H', W']`` (uint8), where
        each frame has been shrunk via ``auto_downsample_height_width``.

    Raises:
        TypeError: if ``input`` is neither a list nor a numpy array.
    """
    if isinstance(input, list):
        image_paths = input
        sampled_indices = sample_indices(len(image_paths))
        images = None
        for i, idx in enumerate(sampled_indices):
            img = load_image_as_numpy(image_paths[idx], dtype=np.uint8, channel_first=True)
            img = auto_downsample_height_width(img)
            if images is None:
                # Allocate the output buffer lazily, once the downsampled shape is known.
                images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
            images[i] = img
    elif isinstance(input, np.ndarray):
        # Insert a channel axis so frames match the [C, H, W] layout of loaded images.
        frames_array = input[:, None, :, :]  # Shape: [T, 1, H, W]
        sampled_indices = sample_indices(len(frames_array))
        images = None
        for i, idx in enumerate(sampled_indices):
            img = auto_downsample_height_width(frames_array[idx])
            if images is None:
                images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
            images[i] = img
    else:
        # Fail fast: previously an unsupported type fell through to
        # `return images` with `images` unbound (UnboundLocalError).
        raise TypeError(f"Unsupported input type for sample_images: {type(input)}")

    return images
|
|
|
|
|
|
def compute_episode_stats(
    episode_data: dict[str, list[str] | np.ndarray],
    features: dict,
    quantile_list: list[float] | None = None,
) -> dict:
    """Compute per-feature statistics for a single episode.

    String features are skipped. Image/video features are subsampled via
    ``sample_images`` and reduced over frame and spatial axes; their stats
    (except ``count``) are divided by 255 — depth features are left in raw
    units. All other features are reduced along the first axis.
    """
    if quantile_list is None:
        quantile_list = DEFAULT_QUANTILES

    stats = {}
    for key, data in episode_data.items():
        dtype = features[key]["dtype"]
        if dtype == "string":
            continue  # no numeric stats for text features

        is_visual = dtype in ["image", "video"]
        if is_visual:
            array = sample_images(data)
            axis = (0, 2, 3)  # reduce over frames and spatial dims, keep channels
            keepdims = True
        else:
            array = data
            axis = 0
            keepdims = data.ndim == 1

        feature_stats = get_feature_stats(
            array, axis=axis, keepdims=keepdims, quantile_list=quantile_list
        )

        if is_visual:
            # Depth stays in raw units; everything else is scaled to [0, 1].
            divisor = 1.0 if "depth" in key else 255.0
            feature_stats = {
                name: value if name == "count" else np.squeeze(value / divisor, axis=0)
                for name, value in feature_stats.items()
            }
        stats[key] = feature_stats

    return stats
|