mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-24 02:09:40 +00:00
⬆️ sync with lerobot v0.5.1 (#96)
* update agibot2lerobot * update libero2lerobot * update robomind2lerobot * fix robomind2lerobot
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
import numpy as np
|
||||
import torchvision
|
||||
from lerobot.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
|
||||
from lerobot.datasets.utils import load_image_as_numpy
|
||||
|
||||
torchvision.set_video_backend("pyav")
|
||||
from lerobot.datasets.compute_stats import (
|
||||
DEFAULT_QUANTILES,
|
||||
auto_downsample_height_width,
|
||||
get_feature_stats,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.datasets.io_utils import load_image_as_numpy
|
||||
|
||||
|
||||
def generate_features_from_config(AgiBotWorld_CONFIG):
|
||||
@@ -49,21 +51,31 @@ def sample_images(input):
|
||||
return images
|
||||
|
||||
|
||||
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
||||
def compute_episode_stats(
|
||||
episode_data: dict[str, list[str] | np.ndarray],
|
||||
features: dict,
|
||||
quantile_list: list[float] | None = None,
|
||||
) -> dict:
|
||||
if quantile_list is None:
|
||||
quantile_list = DEFAULT_QUANTILES
|
||||
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue # HACK: we should receive np.arrays of strings
|
||||
continue
|
||||
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
axes_to_reduce = (0, 2, 3)
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
ep_ft_array = data
|
||||
axes_to_reduce = 0
|
||||
keepdims = data.ndim == 1
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
ep_stats[key] = get_feature_stats(
|
||||
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
value_norm = 1.0 if "depth" in key else 255.0
|
||||
|
||||
Reference in New Issue
Block a user