validate molmoact2 gripper range

This commit is contained in:
hq-fang
2026-05-22 22:12:24 +00:00
parent dca792951e
commit 36d0ba5127
2 changed files with 169 additions and 14 deletions
@@ -431,6 +431,38 @@ def _feature_dim(stats: dict[str, Any] | None) -> int | None:
return None
def _stats_array(value: Any) -> np.ndarray | None:
if value is None:
return None
if torch.is_tensor(value):
return value.detach().cpu().numpy() if value.ndim > 0 else None
arr = np.asarray(value)
return arr if arr.ndim > 0 else None
def _validate_masked_passthrough_stats(feature_stats: dict[str, Any], mask: list[bool], key: str) -> None:
    """Check that the stats of masked-out (pass-through) dims stay inside [-1, 1].

    Entries where ``mask`` is ``False`` are the pass-through dims. The check is
    skipped silently when ``min``/``max`` are missing or scalar, when the mask is
    not one-dimensional, when its length does not match the last axis of the
    stats, or when no dim is masked out.

    Args:
        feature_stats: Stats for one feature; ``"min"`` and ``"max"`` are read.
        mask: Per-dim flags; ``False`` marks a pass-through dim.
        key: Feature key, used only in the error message.

    Raises:
        ValueError: If any pass-through ``min`` is below -1 or any pass-through
            ``max`` is above 1.
    """
    lower = _stats_array(feature_stats.get("min"))
    upper = _stats_array(feature_stats.get("max"))
    if lower is None or upper is None:
        return

    keep_normalized = np.asarray(mask, dtype=bool)
    if keep_normalized.ndim != 1:
        return
    width = keep_normalized.shape[0]
    if lower.shape[-1] != width or upper.shape[-1] != width:
        return
    passthrough = ~keep_normalized
    if not bool(passthrough.any()):
        return

    # Only the last axis is indexed, so stats with leading axes are handled too.
    lower_passthrough = lower[..., passthrough]
    upper_passthrough = upper[..., passthrough]
    too_low = lower_passthrough < -1.0
    too_high = upper_passthrough > 1.0
    if bool((too_low | too_high).any()):
        raise ValueError(
            f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
        )
def _feature_names_from_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None:
if dataset_meta is None:
return None
@@ -501,11 +533,49 @@ def _add_gripper_masks_to_stats(
continue
if len(names) != dim:
continue
feature_stats["mask"] = ["gripper" not in name.lower() for name in names]
mask = ["gripper" not in name.lower() for name in names]
_validate_masked_passthrough_stats(feature_stats, mask, key)
feature_stats["mask"] = mask
return stats
def _normalization_masks_from_stats(
    dataset_stats: dict[str, dict[str, Any]] | None,
) -> dict[str, list[bool]]:
    """Collect the ``"mask"`` entries stored in the action and state stats.

    Args:
        dataset_stats: Mapping from feature key to its stats dict, or ``None``.

    Returns:
        A dict holding, for each of ``ACTION`` and ``OBS_STATE``, the mask as a
        list of bools. A key is omitted when its stats are missing or not a
        dict, or when the mask is not a list made up only of ``bool`` values
        (tensor masks are converted with ``tolist()`` first).
    """
    stats_by_key = dataset_stats or {}
    collected: dict[str, list[bool]] = {}
    for key in (ACTION, OBS_STATE):
        entry = stats_by_key.get(key)
        if not isinstance(entry, dict):
            continue
        candidate = entry.get("mask")
        if isinstance(candidate, Tensor):
            candidate = candidate.detach().cpu().tolist()
        if not isinstance(candidate, list):
            continue
        if all(isinstance(flag, bool) for flag in candidate):
            collected[key] = candidate
    return collected
class _MolmoAct2MaskedNormalizationMixin:
@staticmethod
def _broadcast_feature_mask(mask: Tensor, tensor: Tensor) -> Tensor | None:
mask = mask.to(device=tensor.device, dtype=torch.bool)
if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]:
return None
while mask.ndim < tensor.ndim:
mask = mask.unsqueeze(0)
return mask
@staticmethod
def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None:
passthrough_mask = ~mask.expand_as(tensor)
if not bool(passthrough_mask.any()):
return
passthrough_values = tensor[passthrough_mask]
if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()):
raise ValueError(
f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
)
def _apply_transform(
self, tensor: Tensor, key: str, feature_type: Any, *, inverse: bool = False
) -> Tensor:
@@ -514,11 +584,11 @@ class _MolmoAct2MaskedNormalizationMixin:
mask = stats.get("mask") if isinstance(stats, dict) else None
if mask is None:
return transformed
mask = mask.to(device=tensor.device, dtype=torch.bool)
if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]:
mask = self._broadcast_feature_mask(mask, tensor)
if mask is None:
return transformed
while mask.ndim < tensor.ndim:
mask = mask.unsqueeze(0)
if not inverse:
self._validate_masked_passthrough_range(tensor, mask, key)
return torch.where(mask, transformed, tensor)
@@ -539,16 +609,48 @@ class MolmoAct2MaskedUnnormalizerProcessorStep(_MolmoAct2MaskedNormalizationMixi
class MolmoAct2ClampNormalizedProcessorStep(ProcessorStep):
"""Clamp q01/q99-normalized state and action to the range used by the old trainer."""
normalization_masks: dict[str, list[bool]] | None = None
@staticmethod
def _broadcast_feature_mask(mask: list[bool], tensor: Tensor) -> Tensor | None:
tensor_mask = torch.tensor(mask, device=tensor.device, dtype=torch.bool)
if tensor_mask.ndim != 1 or tensor.shape[-1] != tensor_mask.shape[0]:
return None
while tensor_mask.ndim < tensor.ndim:
tensor_mask = tensor_mask.unsqueeze(0)
return tensor_mask
@staticmethod
def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None:
passthrough_mask = ~mask.expand_as(tensor)
if not bool(passthrough_mask.any()):
return
passthrough_values = tensor[passthrough_mask]
if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()):
raise ValueError(
f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
)
def _clamp_tensor(self, tensor: Tensor, key: str) -> Tensor:
mask = (self.normalization_masks or {}).get(key)
if mask is None:
return tensor.clamp(-1.0, 1.0)
tensor_mask = self._broadcast_feature_mask(mask, tensor)
if tensor_mask is None:
return tensor.clamp(-1.0, 1.0)
self._validate_masked_passthrough_range(tensor, tensor_mask, key)
return torch.where(tensor_mask, tensor.clamp(-1.0, 1.0), tensor)
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Clamp the state and action of ``transition`` through the mask-aware path.

    The transition (and its observation dict, when touched) is shallow-copied,
    so the caller's containers are not modified. Both features go through
    ``_clamp_tensor`` only. Clamping them unconditionally first would squash
    the masked-out pass-through dims and hide out-of-range values from the
    range validation.

    Args:
        transition: The transition to process.

    Returns:
        A copy of ``transition`` with ``OBS_STATE`` (when present in a dict
        observation) and the action (when not ``None``) clamped.

    Raises:
        ValueError: Propagated from ``_clamp_tensor`` when a pass-through
            value lies outside [-1, 1].
    """
    transition = transition.copy()

    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict) and OBS_STATE in observation:
        observation = observation.copy()
        observation[OBS_STATE] = self._clamp_tensor(torch.as_tensor(observation[OBS_STATE]), OBS_STATE)
        transition[TransitionKey.OBSERVATION] = observation

    action = transition.get(TransitionKey.ACTION)
    if action is not None:
        transition[TransitionKey.ACTION] = self._clamp_tensor(torch.as_tensor(action), ACTION)

    return transition
def transform_features(
@@ -924,6 +1026,7 @@ def make_molmoact2_pre_post_processors(
normalize_gripper=config.normalize_gripper,
dataset_feature_names=config.dataset_feature_names,
)
normalization_masks = _normalization_masks_from_stats(masked_dataset_stats)
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
@@ -933,7 +1036,7 @@ def make_molmoact2_pre_post_processors(
norm_map=config.normalization_mapping,
stats=masked_dataset_stats,
),
MolmoAct2ClampNormalizedProcessorStep(),
MolmoAct2ClampNormalizedProcessorStep(normalization_masks=normalization_masks),
MolmoAct2PackInputsProcessorStep(
checkpoint_path=config.checkpoint_path,
checkpoint_revision=config.checkpoint_revision,
+58 -6
View File
@@ -46,6 +46,7 @@ from lerobot.policies.molmoact2.configuration_molmoact2 import (
)
from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy
from lerobot.policies.molmoact2.processor_molmoact2 import (
MolmoAct2ClampNormalizedProcessorStep,
MolmoAct2MaskedNormalizerProcessorStep,
MolmoAct2MaskedUnnormalizerProcessorStep,
MolmoAct2PackInputsProcessorStep,
@@ -185,8 +186,8 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path):
FeatureType.STATE: NormalizationMode.QUANTILES,
}
transition = {
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])},
TransitionKey.ACTION: torch.tensor([[5.0, 7.0]]),
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 0.7]])},
TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]),
}
normalizer = MolmoAct2MaskedNormalizerProcessorStep(
features=features,
@@ -195,17 +196,68 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path):
)
normalized = normalizer(transition)
assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 7.0]]))
assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, 7.0]]))
assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 0.7]]))
assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, -0.7]]))
with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
normalizer(
{
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])},
TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]),
}
)
unnormalizer = MolmoAct2MaskedUnnormalizerProcessorStep(
features={ACTION: features[ACTION]},
norm_map=norm_map,
stats=masked_stats,
)
unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, 7.0]])})
unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, -0.7]])})
assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, 7.0]]))
assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, -0.7]]))
def test_molmoact2_gripper_mask_validates_dataset_stats(tmp_path):
    """Out-of-range gripper stats are rejected unless the gripper is normalized."""
    # Dataset metadata naming the second action dim "gripper".
    info_file = tmp_path / "meta" / "info.json"
    info_file.parent.mkdir()
    info_payload = {"features": {ACTION: {"names": ["x", "gripper"]}}}
    info_file.write_text(json.dumps(info_payload), encoding="utf-8")

    # The gripper dim has min=-2.0, which is outside [-1, 1].
    out_of_range_stats = {ACTION: {"min": [-0.5, -2.0], "max": [0.5, 0.5]}}

    with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
        _add_gripper_masks_to_stats(
            out_of_range_stats, SimpleNamespace(root=tmp_path), normalize_gripper=False
        )

    normalized_stats = _add_gripper_masks_to_stats(
        out_of_range_stats, SimpleNamespace(root=tmp_path), normalize_gripper=True
    )
    assert normalized_stats is not None
    assert normalized_stats[ACTION]["mask"] == [True, True]
def test_molmoact2_clamp_normalized_respects_masked_gripper_dims():
    """The clamp step clamps normalized dims, passes gripper dims through, and validates them."""
    clamp_step = MolmoAct2ClampNormalizedProcessorStep(
        normalization_masks={key: [True, False] for key in (ACTION, OBS_STATE)}
    )

    result = clamp_step(
        {
            TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[-2.0, 0.8]])},
            TransitionKey.ACTION: torch.tensor([[2.0, -0.8]]),
        }
    )

    # First dim is clamped to the boundary; the masked-out second dim is unchanged.
    expected_state = torch.tensor([[-1.0, 0.8]])
    expected_action = torch.tensor([[1.0, -0.8]])
    assert torch.equal(result[TransitionKey.OBSERVATION][OBS_STATE], expected_state)
    assert torch.equal(result[TransitionKey.ACTION], expected_action)

    out_of_range = {TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[0.0, 1.2]])}}
    with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
        clamp_step(out_of_range)
def test_molmoact2_normalize_gripper_true_keeps_all_dims_normalized(tmp_path):