mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-21 08:59:41 +00:00
76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
from lerobot.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
|
|
|
|
torchvision.set_video_backend("pyav")
|
|
|
|
|
|
def generate_features_from_config(AgiBotWorld_CONFIG):
|
|
features = {}
|
|
for key, value in AgiBotWorld_CONFIG["images"].items():
|
|
features[f"observation.images.{key}"] = value
|
|
for key, value in AgiBotWorld_CONFIG["states"].items():
|
|
features[f"observation.states.{key}"] = value
|
|
for key, value in AgiBotWorld_CONFIG["actions"].items():
|
|
features[f"actions.{key}"] = value
|
|
return features
|
|
|
|
|
|
def sample_images(input):
|
|
if type(input) is str:
|
|
video_path = input
|
|
reader = torchvision.io.VideoReader(video_path, stream="video")
|
|
frames = [frame["data"] for frame in reader]
|
|
frames_array = torch.stack(frames).numpy() # Shape: [T, C, H, W]
|
|
|
|
sampled_indices = sample_indices(len(frames_array))
|
|
images = None
|
|
for i, idx in enumerate(sampled_indices):
|
|
img = frames_array[idx]
|
|
img = auto_downsample_height_width(img)
|
|
|
|
if images is None:
|
|
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
|
|
|
images[i] = img
|
|
elif type(input) is np.ndarray:
|
|
frames_array = input[:, None, :, :] # Shape: [T, C, H, W]
|
|
sampled_indices = sample_indices(len(frames_array))
|
|
images = None
|
|
for i, idx in enumerate(sampled_indices):
|
|
img = frames_array[idx]
|
|
img = auto_downsample_height_width(img)
|
|
|
|
if images is None:
|
|
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
|
|
|
images[i] = img
|
|
|
|
return images
|
|
|
|
|
|
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
|
|
ep_stats = {}
|
|
for key, data in episode_data.items():
|
|
if features[key]["dtype"] == "string":
|
|
continue # HACK: we should receive np.arrays of strings
|
|
elif features[key]["dtype"] in ["image", "video"]:
|
|
ep_ft_array = sample_images(data)
|
|
axes_to_reduce = (0, 2, 3) # keep channel dim
|
|
keepdims = True
|
|
else:
|
|
ep_ft_array = data # data is already a np.ndarray
|
|
axes_to_reduce = 0 # compute stats over the first axis
|
|
keepdims = data.ndim == 1 # keep as np.array
|
|
|
|
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
|
|
|
if features[key]["dtype"] in ["image", "video"]:
|
|
value_norm = 1.0 if "depth" in key else 255.0
|
|
ep_stats[key] = {
|
|
k: v if k == "count" else np.squeeze(v / value_norm, axis=0) for k, v in ep_stats[key].items()
|
|
}
|
|
|
|
return ep_stats
|