import numpy as np from lerobot.datasets.compute_stats import ( DEFAULT_QUANTILES, auto_downsample_height_width, get_feature_stats, sample_indices, ) from torchcodec.decoders import VideoDecoder def generate_features_from_config(AgiBotWorld_CONFIG): features = {} for key, value in AgiBotWorld_CONFIG["images"].items(): features[f"observation.images.{key}"] = value for key, value in AgiBotWorld_CONFIG["states"].items(): features[f"observation.states.{key}"] = value for key, value in AgiBotWorld_CONFIG["actions"].items(): features[f"actions.{key}"] = value return features def sample_images(input): if type(input) is str: video_path = input decoder = VideoDecoder(video_path) frames_array = decoder[0:-1].numpy() # Shape: [T, C, H, W] sampled_indices = sample_indices(len(frames_array)) images = None for i, idx in enumerate(sampled_indices): img = frames_array[idx] img = auto_downsample_height_width(img) if images is None: images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) images[i] = img elif type(input) is np.ndarray: frames_array = input[:, None, :, :] # Shape: [T, C, H, W] sampled_indices = sample_indices(len(frames_array)) images = None for i, idx in enumerate(sampled_indices): img = frames_array[idx] img = auto_downsample_height_width(img) if images is None: images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) images[i] = img return images def compute_episode_stats( episode_data: dict[str, list[str] | np.ndarray], features: dict, quantile_list: list[float] | None = None, ) -> dict: if quantile_list is None: quantile_list = DEFAULT_QUANTILES ep_stats = {} for key, data in episode_data.items(): if features[key]["dtype"] == "string": continue elif features[key]["dtype"] in ["image", "video"]: ep_ft_array = sample_images(data) axes_to_reduce = (0, 2, 3) keepdims = True else: ep_ft_array = data axes_to_reduce = 0 keepdims = data.ndim == 1 ep_stats[key] = get_feature_stats( ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list ) if features[key]["dtype"] in ["image", "video"]: value_norm = 1.0 if "depth" in key else 255.0 ep_stats[key] = { k: v if k == "count" else np.squeeze(v / value_norm, axis=0) for k, v in ep_stats[key].items() } return ep_stats