mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 00:07:03 +00:00
validate molmoact2 gripper range
This commit is contained in:
@@ -431,6 +431,38 @@ def _feature_dim(stats: dict[str, Any] | None) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def _stats_array(value: Any) -> np.ndarray | None:
|
||||
if value is None:
|
||||
return None
|
||||
if torch.is_tensor(value):
|
||||
return value.detach().cpu().numpy() if value.ndim > 0 else None
|
||||
arr = np.asarray(value)
|
||||
return arr if arr.ndim > 0 else None
|
||||
|
||||
|
||||
def _validate_masked_passthrough_stats(feature_stats: dict[str, Any], mask: list[bool], key: str) -> None:
    """Reject dataset stats whose unmasked dimensions fall outside [-1, 1].

    Dimensions where ``mask`` is ``False`` are checked against the ``"min"``
    and ``"max"`` entries of ``feature_stats``. The check is skipped silently
    when either statistic is unavailable, when the mask is not one-dimensional,
    when the mask length does not match the last axis of the statistics, or
    when every mask entry is ``True``.

    Args:
        feature_stats: Statistics for a single feature.
        mask: Per-dimension flags; ``False`` marks a dimension to check.
        key: Feature name, used only in the error message.

    Raises:
        ValueError: If a checked dimension has ``min < -1`` or ``max > 1``.
    """
    lower = _stats_array(feature_stats.get("min"))
    upper = _stats_array(feature_stats.get("max"))
    if lower is None or upper is None:
        return

    flags = np.asarray(mask, dtype=bool)
    if flags.ndim != 1:
        return
    width = flags.shape[0]
    if lower.shape[-1] != width or upper.shape[-1] != width:
        return
    unmasked = ~flags
    if not unmasked.any():
        return

    too_low = lower[..., unmasked] < -1.0
    too_high = upper[..., unmasked] > 1.0
    if np.logical_or(too_low, too_high).any():
        raise ValueError(
            f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
        )
|
||||
|
||||
|
||||
def _feature_names_from_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None:
|
||||
if dataset_meta is None:
|
||||
return None
|
||||
@@ -501,11 +533,49 @@ def _add_gripper_masks_to_stats(
|
||||
continue
|
||||
if len(names) != dim:
|
||||
continue
|
||||
feature_stats["mask"] = ["gripper" not in name.lower() for name in names]
|
||||
mask = ["gripper" not in name.lower() for name in names]
|
||||
_validate_masked_passthrough_stats(feature_stats, mask, key)
|
||||
feature_stats["mask"] = mask
|
||||
return stats
|
||||
|
||||
|
||||
def _normalization_masks_from_stats(
    dataset_stats: dict[str, dict[str, Any]] | None,
) -> dict[str, list[bool]]:
    """Collect the ``"mask"`` entries stored for the action and state features.

    Args:
        dataset_stats: Mapping from feature key to that feature's statistics,
            or ``None``.

    Returns:
        A dict holding, for each of ``ACTION`` and ``OBS_STATE``, the mask
        found in its statistics. A tensor mask is converted to a Python list
        first. Entries are kept only when the result is a list made up
        entirely of ``bool`` values; anything else is left out.
    """
    stats_by_feature = dataset_stats or {}
    collected: dict[str, list[bool]] = {}
    for feature_key in (ACTION, OBS_STATE):
        entry = stats_by_feature.get(feature_key)
        if not isinstance(entry, dict):
            continue
        candidate = entry.get("mask")
        if isinstance(candidate, Tensor):
            candidate = candidate.detach().cpu().tolist()
        if not isinstance(candidate, list):
            continue
        if any(not isinstance(flag, bool) for flag in candidate):
            continue
        collected[feature_key] = candidate
    return collected
|
||||
|
||||
|
||||
class _MolmoAct2MaskedNormalizationMixin:
|
||||
@staticmethod
|
||||
def _broadcast_feature_mask(mask: Tensor, tensor: Tensor) -> Tensor | None:
|
||||
mask = mask.to(device=tensor.device, dtype=torch.bool)
|
||||
if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]:
|
||||
return None
|
||||
while mask.ndim < tensor.ndim:
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask
|
||||
|
||||
@staticmethod
|
||||
def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None:
|
||||
passthrough_mask = ~mask.expand_as(tensor)
|
||||
if not bool(passthrough_mask.any()):
|
||||
return
|
||||
passthrough_values = tensor[passthrough_mask]
|
||||
if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()):
|
||||
raise ValueError(
|
||||
f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
|
||||
)
|
||||
|
||||
def _apply_transform(
|
||||
self, tensor: Tensor, key: str, feature_type: Any, *, inverse: bool = False
|
||||
) -> Tensor:
|
||||
@@ -514,11 +584,11 @@ class _MolmoAct2MaskedNormalizationMixin:
|
||||
mask = stats.get("mask") if isinstance(stats, dict) else None
|
||||
if mask is None:
|
||||
return transformed
|
||||
mask = mask.to(device=tensor.device, dtype=torch.bool)
|
||||
if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]:
|
||||
mask = self._broadcast_feature_mask(mask, tensor)
|
||||
if mask is None:
|
||||
return transformed
|
||||
while mask.ndim < tensor.ndim:
|
||||
mask = mask.unsqueeze(0)
|
||||
if not inverse:
|
||||
self._validate_masked_passthrough_range(tensor, mask, key)
|
||||
return torch.where(mask, transformed, tensor)
|
||||
|
||||
|
||||
@@ -539,16 +609,48 @@ class MolmoAct2MaskedUnnormalizerProcessorStep(_MolmoAct2MaskedNormalizationMixi
|
||||
class MolmoAct2ClampNormalizedProcessorStep(ProcessorStep):
|
||||
"""Clamp q01/q99-normalized state and action to the range used by the old trainer."""
|
||||
|
||||
normalization_masks: dict[str, list[bool]] | None = None
|
||||
|
||||
@staticmethod
|
||||
def _broadcast_feature_mask(mask: list[bool], tensor: Tensor) -> Tensor | None:
|
||||
tensor_mask = torch.tensor(mask, device=tensor.device, dtype=torch.bool)
|
||||
if tensor_mask.ndim != 1 or tensor.shape[-1] != tensor_mask.shape[0]:
|
||||
return None
|
||||
while tensor_mask.ndim < tensor.ndim:
|
||||
tensor_mask = tensor_mask.unsqueeze(0)
|
||||
return tensor_mask
|
||||
|
||||
@staticmethod
|
||||
def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None:
|
||||
passthrough_mask = ~mask.expand_as(tensor)
|
||||
if not bool(passthrough_mask.any()):
|
||||
return
|
||||
passthrough_values = tensor[passthrough_mask]
|
||||
if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()):
|
||||
raise ValueError(
|
||||
f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True."
|
||||
)
|
||||
|
||||
def _clamp_tensor(self, tensor: Tensor, key: str) -> Tensor:
|
||||
mask = (self.normalization_masks or {}).get(key)
|
||||
if mask is None:
|
||||
return tensor.clamp(-1.0, 1.0)
|
||||
tensor_mask = self._broadcast_feature_mask(mask, tensor)
|
||||
if tensor_mask is None:
|
||||
return tensor.clamp(-1.0, 1.0)
|
||||
self._validate_masked_passthrough_range(tensor, tensor_mask, key)
|
||||
return torch.where(tensor_mask, tensor.clamp(-1.0, 1.0), tensor)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Return a copy of ``transition`` with state and action clamped.

    The observation state (when the observation is a dict containing
    ``OBS_STATE``) and the action (when present) are converted to tensors and
    passed through ``_clamp_tensor``. The input transition and its observation
    dict are shallow-copied, so the caller's containers are not modified.

    Raises:
        ValueError: Propagated from ``_clamp_tensor`` when a dimension that is
            exempt from clamping holds a value outside [-1, 1].
    """
    transition = transition.copy()

    observation = transition.get(TransitionKey.OBSERVATION)
    if isinstance(observation, dict) and OBS_STATE in observation:
        observation = observation.copy()
        # Only the mask-aware clamp may run here: an unconditional clamp
        # beforehand would squash the exempt dimensions and hide out-of-range
        # values from the validation inside ``_clamp_tensor``.
        observation[OBS_STATE] = self._clamp_tensor(torch.as_tensor(observation[OBS_STATE]), OBS_STATE)
        transition[TransitionKey.OBSERVATION] = observation

    action = transition.get(TransitionKey.ACTION)
    if action is not None:
        transition[TransitionKey.ACTION] = self._clamp_tensor(torch.as_tensor(action), ACTION)

    return transition
|
||||
|
||||
def transform_features(
|
||||
@@ -924,6 +1026,7 @@ def make_molmoact2_pre_post_processors(
|
||||
normalize_gripper=config.normalize_gripper,
|
||||
dataset_feature_names=config.dataset_feature_names,
|
||||
)
|
||||
normalization_masks = _normalization_masks_from_stats(masked_dataset_stats)
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
@@ -933,7 +1036,7 @@ def make_molmoact2_pre_post_processors(
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=masked_dataset_stats,
|
||||
),
|
||||
MolmoAct2ClampNormalizedProcessorStep(),
|
||||
MolmoAct2ClampNormalizedProcessorStep(normalization_masks=normalization_masks),
|
||||
MolmoAct2PackInputsProcessorStep(
|
||||
checkpoint_path=config.checkpoint_path,
|
||||
checkpoint_revision=config.checkpoint_revision,
|
||||
|
||||
@@ -46,6 +46,7 @@ from lerobot.policies.molmoact2.configuration_molmoact2 import (
|
||||
)
|
||||
from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy
|
||||
from lerobot.policies.molmoact2.processor_molmoact2 import (
|
||||
MolmoAct2ClampNormalizedProcessorStep,
|
||||
MolmoAct2MaskedNormalizerProcessorStep,
|
||||
MolmoAct2MaskedUnnormalizerProcessorStep,
|
||||
MolmoAct2PackInputsProcessorStep,
|
||||
@@ -185,8 +186,8 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path):
|
||||
FeatureType.STATE: NormalizationMode.QUANTILES,
|
||||
}
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])},
|
||||
TransitionKey.ACTION: torch.tensor([[5.0, 7.0]]),
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 0.7]])},
|
||||
TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]),
|
||||
}
|
||||
normalizer = MolmoAct2MaskedNormalizerProcessorStep(
|
||||
features=features,
|
||||
@@ -195,17 +196,68 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path):
|
||||
)
|
||||
normalized = normalizer(transition)
|
||||
|
||||
assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 7.0]]))
|
||||
assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, 7.0]]))
|
||||
assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 0.7]]))
|
||||
assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, -0.7]]))
|
||||
|
||||
with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
|
||||
normalizer(
|
||||
{
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])},
|
||||
TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]),
|
||||
}
|
||||
)
|
||||
|
||||
unnormalizer = MolmoAct2MaskedUnnormalizerProcessorStep(
|
||||
features={ACTION: features[ACTION]},
|
||||
norm_map=norm_map,
|
||||
stats=masked_stats,
|
||||
)
|
||||
unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, 7.0]])})
|
||||
unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, -0.7]])})
|
||||
|
||||
assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, 7.0]]))
|
||||
assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, -0.7]]))
|
||||
|
||||
|
||||
def test_molmoact2_gripper_mask_validates_dataset_stats(tmp_path):
    """Out-of-range gripper stats are rejected unless the gripper is normalized."""
    # Minimal dataset metadata: a two-dimensional action whose second entry is the gripper.
    info_payload = {"features": {ACTION: {"names": ["x", "gripper"]}}}
    meta_dir = tmp_path / "meta"
    meta_dir.mkdir()
    info_file = meta_dir / "info.json"
    info_file.write_text(json.dumps(info_payload), encoding="utf-8")

    # The gripper minimum (-2.0) lies outside [-1, 1].
    stats = {
        ACTION: {
            "min": [-0.5, -2.0],
            "max": [0.5, 0.5],
        }
    }

    with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
        _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=False)

    masked_stats = _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=True)
    assert masked_stats is not None
    action_mask = masked_stats[ACTION]["mask"]
    assert action_mask == [True, True]
|
||||
|
||||
|
||||
def test_molmoact2_clamp_normalized_respects_masked_gripper_dims():
    """The clamp step clamps masked dims, passes gripper dims through, and validates them."""
    gripper_last = [True, False]
    step = MolmoAct2ClampNormalizedProcessorStep(
        normalization_masks={ACTION: gripper_last, OBS_STATE: gripper_last}
    )

    state_in = torch.tensor([[-2.0, 0.8]])
    action_in = torch.tensor([[2.0, -0.8]])
    clamped = step({TransitionKey.OBSERVATION: {OBS_STATE: state_in}, TransitionKey.ACTION: action_in})

    # First dim is clamped to the boundary; the gripper dim is returned as given.
    expected_state = torch.tensor([[-1.0, 0.8]])
    expected_action = torch.tensor([[1.0, -0.8]])
    assert torch.equal(clamped[TransitionKey.OBSERVATION][OBS_STATE], expected_state)
    assert torch.equal(clamped[TransitionKey.ACTION], expected_action)

    # A gripper value outside [-1, 1] must be reported instead of clamped.
    bad_transition = {TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[0.0, 1.2]])}}
    with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"):
        step(bad_transition)
|
||||
|
||||
|
||||
def test_molmoact2_normalize_gripper_true_keeps_all_dims_normalized(tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user