fix: changes to compute stats and modeling

This commit is contained in:
danaaubakirova
2025-07-11 15:50:22 +02:00
parent 008b592545
commit 7848b15bfb
6 changed files with 86 additions and 10 deletions
+15
View File
@@ -37,6 +37,21 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
# Multi-dataset support
sampling_weights: str | None = None
max_action_dim: int | None = None
max_state_dim: int | None = None
max_num_images: int | None = None
max_image_dim: int | None = None
train_on_all_features: bool = False
features_version: int = 0
discard_first_n_frames: int = 0
min_fps: int = 1
max_fps: int = 100
discard_first_idle_frames: bool = False
motion_threshold: float = 5e-2
motion_window_size: int = 10
motion_buffer: int = 3
@dataclass
+26 -5
View File
@@ -125,9 +125,30 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
# Filter out stats that don't have required keys
valid_stats = []
for s in stats_ft_list:
if all(key in s for key in ["mean", "std", "count", "min", "max"]):
valid_stats.append(s)
else:
# If count is missing, add it with a default value
if "count" not in s:
s["count"] = np.array([1]) # Default count
valid_stats.append(s)
if not valid_stats:
# If no valid stats, return empty stats
return {
"min": np.array([0]),
"max": np.array([0]),
"mean": np.array([0]),
"std": np.array([0]),
"count": np.array([0]),
}
means = np.stack([s["mean"] for s in valid_stats])
variances = np.stack([s["std"] ** 2 for s in valid_stats])
counts = np.stack([s["count"] for s in valid_stats])
total_count = counts.sum(axis=0)
# Prepare weighted mean by matching number of dimensions
@@ -144,8 +165,8 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"min": np.min(np.stack([s["min"] for s in valid_stats]), axis=0),
"max": np.max(np.stack([s["max"] for s in valid_stats]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,
+3 -1
View File
@@ -32,7 +32,7 @@ IMAGENET_STATS = {
"std": [[[0.229]], [[0.224]], [[0.225]]], # (c,1,1)
}
from lerobot.common.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
from lerobot.datasets.utils_must import EPISODES_DATASET_MAPPING, FEATURE_KEYS_MAPPING
def resolve_delta_timestamps(
@@ -106,6 +106,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
image_transforms=image_transforms,
revision=revision,
video_backend=cfg.dataset.video_backend,
download_videos=True,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.dataset.max_action_dim,
max_state_dim=cfg.dataset.max_state_dim,
@@ -132,6 +133,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=cfg.dataset.video_backend,
download_videos=True,
sampling_weights=sampling_weights,
feature_keys_mapping=feature_keys_mapping,
max_action_dim=cfg.policy.max_action_dim,
@@ -65,13 +65,13 @@ from torch import Tensor, nn
from transformers import AutoProcessor
from lerobot.constants import ACTION, OBS_STATE
from lerobot.datasets import IMAGES_ORDER
from lerobot.configs.datasets import IMAGES_ORDER
from lerobot.policies.normalize import (
Normalize,
Unnormalize,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
from lerobot.policies.smolvla2.smolvlm_with_expert2 import SmolVLMWithExpertModel
from lerobot.policies.smolvla2.configuration_smolvla2 import SmolVLA2Config
from lerobot.policies.utils import (
populate_queues,
@@ -389,10 +389,41 @@ class SmolVLA2Policy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
return actions
def merge_peft_model_weights(self) -> None:
if "lora" in self.config.peft_method:
self.model.vlm_with_expert.merge_lora_weights()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
return actions
@torch.no_grad
def select_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
@@ -691,7 +722,6 @@ class VLAFlowMatching(nn.Module):
model_id=self.config.vlm_model_name,
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
load_vlm_weights=self.config.load_vlm_weights,
attention_mode=self.config.attention_mode,
num_expert_layers=self.config.num_expert_layers,
@@ -24,6 +24,7 @@ from transformers import (
AutoProcessor,
SmolVLMForConditionalGeneration,
)
from peft import LoraConfig, TaskType, get_peft_model
def apply_rope(x, positions, max_wavelength=10_000):
+8 -1
View File
@@ -175,11 +175,18 @@ def train(cfg: TrainPipelineConfig):
else:
shuffle = True
sampler = None
keys_to_max_dim = getattr(dataset.meta, "keys_to_max_dim", {})
keys_to_max_dim = {
"action": (32,),
"observation.state": (32,),
"observation.image": (3, 1080, 1920),
"observation.image2": (3, 1080, 1920),
}
collate_fn = partial(multidataset_collate_fn, keys_to_max_dim=keys_to_max_dim)
dataloader = torch.utils.data.DataLoader(
dataset,
collate_fn=collate_fn,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,