From d09b2a28af6879bd118325f300a31685f2482198 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Thu, 11 Sep 2025 14:28:46 +0200 Subject: [PATCH] remove --- .../policies/smolvla/smolvlm_with_expert.py | 1 - src/lerobot/scripts/eval.py | 42 ------------------- src/lerobot/scripts/train.py | 29 ------------- 3 files changed, 72 deletions(-) diff --git a/src/lerobot/policies/smolvla/smolvlm_with_expert.py b/src/lerobot/policies/smolvla/smolvlm_with_expert.py index 3b78c99e6..f3d1a693a 100644 --- a/src/lerobot/policies/smolvla/smolvlm_with_expert.py +++ b/src/lerobot/policies/smolvla/smolvlm_with_expert.py @@ -78,7 +78,6 @@ class SmolVLMWithExpertModel(nn.Module): model_id, device_map="auto", torch_dtype="bfloat16", - # torch_dtype=torch.float16, low_cpu_mem_usage=True, ) config = self.vlm.config diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 2bc928117..5e8d63f09 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -465,48 +465,6 @@ def _compile_episode_data( return data_dict - -from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata -from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy - - -def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): - """Recreate normalization layers with proper stats from the dataset.""" - from lerobot.policies.normalize import Normalize, Unnormalize - - # Convert numpy stats to the format expected by normalization layers - stats = {} - for key, stat_dict in dataset_meta.stats.items(): - stats[key] = { - stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array - for stat_type, stat_array in stat_dict.items() - } - - # Recreate normalization layers with proper stats - normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) - - normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) - - unnormalize_outputs = Unnormalize( - policy.config.output_features, policy.config.normalization_mapping, stats - ) - - # Replace the normalization layers on the policy - policy.normalize_inputs = normalize_inputs - policy.normalize_targets = normalize_targets - policy.unnormalize_outputs = unnormalize_outputs - - print("Normalization layers recreated with dataset stats.") - - -def load_smolvla(cfg, dataset_repo: str, policy): - from lerobot.datasets.lerobot_dataset import LeRobotDataset - - dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/") - _inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing - return policy.to("cuda"), dataset - - @parser.wrap() def eval_main(cfg: EvalPipelineConfig): logging.info(pformat(asdict(cfg))) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index fc8be4ebc..3aea697d0 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -105,35 +105,6 @@ def update_policy( train_metrics.update_s = time.perf_counter() - start_time return train_metrics, output_dict - -# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata): -# """Recreate normalization layers with dataset stats if missing (Adil's workaround).""" -# from lerobot.policies.normalize import Normalize, Unnormalize - -# if not hasattr(dataset_meta, "stats") or not dataset_meta.stats: -# print("⚠️ Dataset has no stats, skipping normalization injection.") -# return - -# stats = {} -# for key, stat_dict in dataset_meta.stats.items(): -# stats[key] = { -# stat_type: torch.as_tensor(stat_array) -# if isinstance(stat_array, np.ndarray) -# else stat_array -# for stat_type, stat_array in stat_dict.items() -# } - -# normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats) -# normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats) -# unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats) - -# policy.normalize_inputs = normalize_inputs -# policy.normalize_targets = normalize_targets -# policy.unnormalize_outputs = unnormalize_outputs - -# print("✅ Normalization layers injected with dataset stats.") - - @parser.wrap() def train(cfg: TrainPipelineConfig): cfg.validate()