This commit is contained in:
Jade Choghari
2025-09-11 14:28:46 +02:00
parent f2530570e0
commit d09b2a28af
3 changed files with 0 additions and 72 deletions
@@ -78,7 +78,6 @@ class SmolVLMWithExpertModel(nn.Module):
model_id,
device_map="auto",
torch_dtype="bfloat16",
# torch_dtype=torch.float16,
low_cpu_mem_usage=True,
)
config = self.vlm.config
-42
View File
@@ -465,48 +465,6 @@ def _compile_episode_data(
return data_dict
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
    """Recreate the policy's normalization layers from the dataset statistics.

    Builds fresh ``Normalize`` / ``Unnormalize`` layers from
    ``dataset_meta.stats`` and assigns them to ``policy.normalize_inputs``,
    ``policy.normalize_targets`` and ``policy.unnormalize_outputs``.

    Args:
        policy: Policy whose normalization layers are replaced in place. Its
            ``config`` must expose ``input_features``, ``output_features`` and
            ``normalization_mapping``.
        dataset_meta: Dataset metadata providing ``stats``, a mapping of
            feature key -> {stat name -> value}.

    Returns:
        None. When the dataset carries no stats, the policy is left untouched
        and a warning is logged.
    """
    from lerobot.policies.normalize import Normalize, Unnormalize

    # Without this guard a missing/None ``stats`` surfaces as an opaque
    # AttributeError on ``.items()``; skip instead and keep the policy's
    # existing layers.
    raw_stats = getattr(dataset_meta, "stats", None)
    if not raw_stats:
        logging.warning("Dataset has no stats, skipping normalization injection.")
        return

    # Convert numpy arrays to torch tensors; any other value is passed through
    # unchanged.
    stats = {
        key: {
            stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
            for stat_type, stat_array in stat_dict.items()
        }
        for key, stat_dict in raw_stats.items()
    }

    config = policy.config
    mapping = config.normalization_mapping

    # Build all three layers before touching the policy so that a failure in
    # any constructor leaves the policy unmodified.
    normalize_inputs = Normalize(config.input_features, mapping, stats)
    normalize_targets = Normalize(config.output_features, mapping, stats)
    unnormalize_outputs = Unnormalize(config.output_features, mapping, stats)

    policy.normalize_inputs = normalize_inputs
    policy.normalize_targets = normalize_targets
    policy.unnormalize_outputs = unnormalize_outputs

    logging.info("Normalization layers recreated with dataset stats.")
def load_smolvla(
    cfg,
    dataset_repo: str,
    policy,
    *,
    root: str = "/raid/jade/.cache/huggingface/datasets/",
    device: str = "cuda",
):
    """Load a dataset, inject its stats into ``policy`` and move the policy to a device.

    Args:
        cfg: Unused by this function; kept so existing call sites keep working.
        dataset_repo: Repository id passed to ``LeRobotDataset``.
        policy: Policy whose normalization layers are rebuilt from the
            dataset's metadata before it is moved to ``device``.
        root: Local dataset root passed to ``LeRobotDataset``. The default is
            the path this function previously hard-coded; pass another
            location when running on a different machine.
        device: Target passed to ``policy.to``. Defaults to ``"cuda"``, the
            previously hard-coded value.

    Returns:
        A ``(policy, dataset)`` tuple, where ``policy`` is the result of
        ``policy.to(device)``.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset(dataset_repo, root=root)
    # Only needed if stats are missing from the policy.
    _inject_normalization_stats(policy=policy, dataset_meta=dataset.meta)
    return policy.to(device), dataset
@parser.wrap()
def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
-29
View File
@@ -105,35 +105,6 @@ def update_policy(
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
# from lerobot.policies.normalize import Normalize, Unnormalize
# if not hasattr(dataset_meta, "stats") or not dataset_meta.stats:
# print("⚠️ Dataset has no stats, skipping normalization injection.")
# return
# stats = {}
# for key, stat_dict in dataset_meta.stats.items():
# stats[key] = {
# stat_type: torch.as_tensor(stat_array)
# if isinstance(stat_array, np.ndarray)
# else stat_array
# for stat_type, stat_array in stat_dict.items()
# }
# normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
# normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
# unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats)
# policy.normalize_inputs = normalize_inputs
# policy.normalize_targets = normalize_targets
# policy.unnormalize_outputs = unnormalize_outputs
# print("✅ Normalization layers injected with dataset stats.")
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()