mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
remove
This commit is contained in:
@@ -78,7 +78,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype="bfloat16",
|
||||
# torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
|
||||
@@ -465,48 +465,6 @@ def _compile_episode_data(
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
|
||||
def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
"""Recreate normalization layers with proper stats from the dataset."""
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
# Convert numpy stats to the format expected by normalization layers
|
||||
stats = {}
|
||||
for key, stat_dict in dataset_meta.stats.items():
|
||||
stats[key] = {
|
||||
stat_type: torch.from_numpy(stat_array) if isinstance(stat_array, np.ndarray) else stat_array
|
||||
for stat_type, stat_array in stat_dict.items()
|
||||
}
|
||||
|
||||
# Recreate normalization layers with proper stats
|
||||
normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
unnormalize_outputs = Unnormalize(
|
||||
policy.config.output_features, policy.config.normalization_mapping, stats
|
||||
)
|
||||
|
||||
# Replace the normalization layers on the policy
|
||||
policy.normalize_inputs = normalize_inputs
|
||||
policy.normalize_targets = normalize_targets
|
||||
policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
print("Normalization layers recreated with dataset stats.")
|
||||
|
||||
|
||||
def load_smolvla(cfg, dataset_repo: str, policy):
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
dataset = LeRobotDataset(dataset_repo, root="/raid/jade/.cache/huggingface/datasets/")
|
||||
_inject_normalization_stats(policy=policy, dataset_meta=dataset.meta) # only needed if stats are missing
|
||||
return policy.to("cuda"), dataset
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
@@ -105,35 +105,6 @@ def update_policy(
|
||||
train_metrics.update_s = time.perf_counter() - start_time
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
# def _inject_normalization_stats(policy: SmolVLAPolicy, dataset_meta: LeRobotDatasetMetadata):
|
||||
# """Recreate normalization layers with dataset stats if missing (Adil's workaround)."""
|
||||
# from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
# if not hasattr(dataset_meta, "stats") or not dataset_meta.stats:
|
||||
# print("⚠️ Dataset has no stats, skipping normalization injection.")
|
||||
# return
|
||||
|
||||
# stats = {}
|
||||
# for key, stat_dict in dataset_meta.stats.items():
|
||||
# stats[key] = {
|
||||
# stat_type: torch.as_tensor(stat_array)
|
||||
# if isinstance(stat_array, np.ndarray)
|
||||
# else stat_array
|
||||
# for stat_type, stat_array in stat_dict.items()
|
||||
# }
|
||||
|
||||
# normalize_inputs = Normalize(policy.config.input_features, policy.config.normalization_mapping, stats)
|
||||
# normalize_targets = Normalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
# unnormalize_outputs = Unnormalize(policy.config.output_features, policy.config.normalization_mapping, stats)
|
||||
|
||||
# policy.normalize_inputs = normalize_inputs
|
||||
# policy.normalize_targets = normalize_targets
|
||||
# policy.unnormalize_outputs = unnormalize_outputs
|
||||
|
||||
# print("✅ Normalization layers injected with dataset stats.")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
cfg.validate()
|
||||
|
||||
Reference in New Issue
Block a user