mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
fix(make_policy): rename mapping edge cases in training (#2332)
* fix bug * update fixes * add hf license * more fixes * add transformers * iterate on review * more fixes * more fixes * add a False test * reduce img size * reduce img size * skip the test * add * add style
This commit is contained in:
@@ -38,6 +38,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
@@ -420,20 +421,7 @@ def make_policy(
|
||||
# policy = torch.compile(policy, mode="reduce-overhead")
|
||||
|
||||
if not rename_map:
|
||||
expected_features = set(cfg.input_features.keys()) | set(cfg.output_features.keys())
|
||||
provided_features = set(features.keys())
|
||||
if expected_features and provided_features != expected_features:
|
||||
missing = expected_features - provided_features
|
||||
extra = provided_features - expected_features
|
||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||
raise ValueError(
|
||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||
f"use the `--rename_map` argument, for example:\n"
|
||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||
)
|
||||
validate_visual_features_consistency(cfg, features)
|
||||
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -22,6 +22,8 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.datasets.utils import build_dataset_frame
|
||||
from lerobot.processor import PolicyAction, RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
@@ -198,3 +200,42 @@ def make_robot_action(action_tensor: PolicyAction, ds_features: dict[str, dict])
|
||||
f"{name}": float(action_tensor[i]) for i, name in enumerate(action_names)
|
||||
}
|
||||
return act_processed_policy
|
||||
|
||||
|
||||
def raise_feature_mismatch_error(
|
||||
provided_features: set[str],
|
||||
expected_features: set[str],
|
||||
) -> None:
|
||||
"""
|
||||
Raises a standardized ValueError for feature mismatches between dataset/environment and policy config.
|
||||
"""
|
||||
missing = expected_features - provided_features
|
||||
extra = provided_features - expected_features
|
||||
# TODO (jadechoghari): provide a dynamic rename map suggestion to the user.
|
||||
raise ValueError(
|
||||
f"Feature mismatch between dataset/environment and policy config.\n"
|
||||
f"- Missing features: {sorted(missing) if missing else 'None'}\n"
|
||||
f"- Extra features: {sorted(extra) if extra else 'None'}\n\n"
|
||||
f"Please ensure your dataset and policy use consistent feature names.\n"
|
||||
f"If your dataset uses different observation keys (e.g., cameras named differently), "
|
||||
f"use the `--rename_map` argument, for example:\n"
|
||||
f' --rename_map=\'{{"observation.images.left": "observation.images.camera1", '
|
||||
f'"observation.images.top": "observation.images.camera2"}}\''
|
||||
)
|
||||
|
||||
|
||||
def validate_visual_features_consistency(
|
||||
cfg: PreTrainedConfig,
|
||||
features: dict[str, PolicyFeature],
|
||||
) -> None:
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
features (Dict[str, PolicyFeature]): A mapping of feature names to PolicyFeature objects.
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
Reference in New Issue
Block a user