diff --git a/src/lerobot/policies/molmoact2/processor_molmoact2.py b/src/lerobot/policies/molmoact2/processor_molmoact2.py index b393dd440..6c7a3ed5c 100644 --- a/src/lerobot/policies/molmoact2/processor_molmoact2.py +++ b/src/lerobot/policies/molmoact2/processor_molmoact2.py @@ -431,6 +431,38 @@ def _feature_dim(stats: dict[str, Any] | None) -> int | None: return None +def _stats_array(value: Any) -> np.ndarray | None: + if value is None: + return None + if torch.is_tensor(value): + return value.detach().cpu().numpy() if value.ndim > 0 else None + arr = np.asarray(value) + return arr if arr.ndim > 0 else None + + +def _validate_masked_passthrough_stats(feature_stats: dict[str, Any], mask: list[bool], key: str) -> None: + min_values = _stats_array(feature_stats.get("min")) + max_values = _stats_array(feature_stats.get("max")) + if min_values is None or max_values is None: + return + + mask_array = np.asarray(mask, dtype=bool) + if ( + mask_array.ndim != 1 + or min_values.shape[-1] != mask_array.shape[0] + or max_values.shape[-1] != mask_array.shape[0] + or not bool((~mask_array).any()) + ): + return + + passthrough_min = min_values[..., ~mask_array] + passthrough_max = max_values[..., ~mask_array] + if bool(((passthrough_min < -1.0) | (passthrough_max > 1.0)).any()): + raise ValueError( + f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True." + ) + + def _feature_names_from_meta(dataset_meta: Any | None, feature_key: str) -> list[str] | None: if dataset_meta is None: return None @@ -501,11 +533,49 @@ def _add_gripper_masks_to_stats( continue if len(names) != dim: continue - feature_stats["mask"] = ["gripper" not in name.lower() for name in names] + mask = ["gripper" not in name.lower() for name in names] + _validate_masked_passthrough_stats(feature_stats, mask, key) + feature_stats["mask"] = mask return stats +def _normalization_masks_from_stats( + dataset_stats: dict[str, dict[str, Any]] | None, +) -> dict[str, list[bool]]: + masks: dict[str, list[bool]] = {} + for key in (ACTION, OBS_STATE): + feature_stats = (dataset_stats or {}).get(key) + if not isinstance(feature_stats, dict): + continue + mask = feature_stats.get("mask") + if isinstance(mask, Tensor): + mask = mask.detach().cpu().tolist() + if isinstance(mask, list) and all(isinstance(value, bool) for value in mask): + masks[key] = mask + return masks + + class _MolmoAct2MaskedNormalizationMixin: + @staticmethod + def _broadcast_feature_mask(mask: Tensor, tensor: Tensor) -> Tensor | None: + mask = mask.to(device=tensor.device, dtype=torch.bool) + if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]: + return None + while mask.ndim < tensor.ndim: + mask = mask.unsqueeze(0) + return mask + + @staticmethod + def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None: + passthrough_mask = ~mask.expand_as(tensor) + if not bool(passthrough_mask.any()): + return + passthrough_values = tensor[passthrough_mask] + if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()): + raise ValueError( + f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True." + ) + def _apply_transform( self, tensor: Tensor, key: str, feature_type: Any, *, inverse: bool = False ) -> Tensor: @@ -514,11 +584,11 @@ class _MolmoAct2MaskedNormalizationMixin: mask = stats.get("mask") if isinstance(stats, dict) else None if mask is None: return transformed - mask = mask.to(device=tensor.device, dtype=torch.bool) - if mask.ndim != 1 or tensor.shape[-1] != mask.shape[0]: + mask = self._broadcast_feature_mask(mask, tensor) + if mask is None: return transformed - while mask.ndim < tensor.ndim: - mask = mask.unsqueeze(0) + if not inverse: + self._validate_masked_passthrough_range(tensor, mask, key) return torch.where(mask, transformed, tensor) @@ -539,16 +609,48 @@ class MolmoAct2MaskedUnnormalizerProcessorStep(_MolmoAct2MaskedNormalizationMixi class MolmoAct2ClampNormalizedProcessorStep(ProcessorStep): """Clamp q01/q99-normalized state and action to the range used by the old trainer.""" + normalization_masks: dict[str, list[bool]] | None = None + + @staticmethod + def _broadcast_feature_mask(mask: list[bool], tensor: Tensor) -> Tensor | None: + tensor_mask = torch.tensor(mask, device=tensor.device, dtype=torch.bool) + if tensor_mask.ndim != 1 or tensor.shape[-1] != tensor_mask.shape[0]: + return None + while tensor_mask.ndim < tensor.ndim: + tensor_mask = tensor_mask.unsqueeze(0) + return tensor_mask + + @staticmethod + def _validate_masked_passthrough_range(tensor: Tensor, mask: Tensor, key: str) -> None: + passthrough_mask = ~mask.expand_as(tensor) + if not bool(passthrough_mask.any()): + return + passthrough_values = tensor[passthrough_mask] + if bool(((passthrough_values < -1.0) | (passthrough_values > 1.0)).any()): + raise ValueError( + f"MolmoAct2 {key} gripper values are not under [-1, 1]. Please set normalize_gripper=True." + ) + + def _clamp_tensor(self, tensor: Tensor, key: str) -> Tensor: + mask = (self.normalization_masks or {}).get(key) + if mask is None: + return tensor.clamp(-1.0, 1.0) + tensor_mask = self._broadcast_feature_mask(mask, tensor) + if tensor_mask is None: + return tensor.clamp(-1.0, 1.0) + self._validate_masked_passthrough_range(tensor, tensor_mask, key) + return torch.where(tensor_mask, tensor.clamp(-1.0, 1.0), tensor) + def __call__(self, transition: EnvTransition) -> EnvTransition: transition = transition.copy() observation = transition.get(TransitionKey.OBSERVATION) if isinstance(observation, dict) and OBS_STATE in observation: observation = observation.copy() - observation[OBS_STATE] = torch.as_tensor(observation[OBS_STATE]).clamp(-1.0, 1.0) + observation[OBS_STATE] = self._clamp_tensor(torch.as_tensor(observation[OBS_STATE]), OBS_STATE) transition[TransitionKey.OBSERVATION] = observation action = transition.get(TransitionKey.ACTION) if action is not None: - transition[TransitionKey.ACTION] = torch.as_tensor(action).clamp(-1.0, 1.0) + transition[TransitionKey.ACTION] = self._clamp_tensor(torch.as_tensor(action), ACTION) return transition def transform_features( @@ -924,6 +1026,7 @@ def make_molmoact2_pre_post_processors( normalize_gripper=config.normalize_gripper, dataset_feature_names=config.dataset_feature_names, ) + normalization_masks = _normalization_masks_from_stats(masked_dataset_stats) input_steps: list[ProcessorStep] = [ RenameObservationsProcessorStep(rename_map={}), @@ -933,7 +1036,7 @@ def make_molmoact2_pre_post_processors( norm_map=config.normalization_mapping, stats=masked_dataset_stats, ), - MolmoAct2ClampNormalizedProcessorStep(), + MolmoAct2ClampNormalizedProcessorStep(normalization_masks=normalization_masks), MolmoAct2PackInputsProcessorStep( checkpoint_path=config.checkpoint_path, checkpoint_revision=config.checkpoint_revision, diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py index b2c6cbfca..3631bcc9b 100644 --- a/tests/policies/molmoact2/test_molmoact2.py +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -46,6 +46,7 @@ from lerobot.policies.molmoact2.configuration_molmoact2 import ( ) from lerobot.policies.molmoact2.modeling_molmoact2 import MolmoAct2Policy from lerobot.policies.molmoact2.processor_molmoact2 import ( + MolmoAct2ClampNormalizedProcessorStep, MolmoAct2MaskedNormalizerProcessorStep, MolmoAct2MaskedUnnormalizerProcessorStep, MolmoAct2PackInputsProcessorStep, @@ -185,8 +186,8 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path): FeatureType.STATE: NormalizationMode.QUANTILES, } transition = { - TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])}, - TransitionKey.ACTION: torch.tensor([[5.0, 7.0]]), + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 0.7]])}, + TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]), } normalizer = MolmoAct2MaskedNormalizerProcessorStep( features=features, @@ -195,17 +196,68 @@ def test_molmoact2_gripper_mask_uses_feature_names(tmp_path): ) normalized = normalizer(transition) - assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 7.0]])) - assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, 7.0]])) + assert torch.equal(normalized[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[0.0, 0.7]])) + assert torch.equal(normalized[TransitionKey.ACTION], torch.tensor([[0.0, -0.7]])) + + with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + normalizer( + { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[5.0, 7.0]])}, + TransitionKey.ACTION: torch.tensor([[5.0, -0.7]]), + } + ) unnormalizer = MolmoAct2MaskedUnnormalizerProcessorStep( features={ACTION: features[ACTION]}, norm_map=norm_map, stats=masked_stats, ) - unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, 7.0]])}) + unnormalized = unnormalizer({TransitionKey.ACTION: torch.tensor([[0.0, -0.7]])}) - assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, 7.0]])) + assert torch.equal(unnormalized[TransitionKey.ACTION], torch.tensor([[5.0, -0.7]])) + + +def test_molmoact2_gripper_mask_validates_dataset_stats(tmp_path): + meta_dir = tmp_path / "meta" + meta_dir.mkdir() + (meta_dir / "info.json").write_text( + json.dumps({"features": {ACTION: {"names": ["x", "gripper"]}}}), + encoding="utf-8", + ) + stats = { + ACTION: { + "min": [-0.5, -2.0], + "max": [0.5, 0.5], + } + } + + with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=False) + + masked_stats = _add_gripper_masks_to_stats(stats, SimpleNamespace(root=tmp_path), normalize_gripper=True) + assert masked_stats is not None + assert masked_stats[ACTION]["mask"] == [True, True] + + +def test_molmoact2_clamp_normalized_respects_masked_gripper_dims(): + step = MolmoAct2ClampNormalizedProcessorStep( + normalization_masks={ + ACTION: [True, False], + OBS_STATE: [True, False], + } + ) + transition = { + TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[-2.0, 0.8]])}, + TransitionKey.ACTION: torch.tensor([[2.0, -0.8]]), + } + + clamped = step(transition) + + assert torch.equal(clamped[TransitionKey.OBSERVATION][OBS_STATE], torch.tensor([[-1.0, 0.8]])) + assert torch.equal(clamped[TransitionKey.ACTION], torch.tensor([[1.0, -0.8]])) + + with pytest.raises(ValueError, match="gripper values are not under \\[-1, 1\\]"): + step({TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[0.0, 1.2]])}}) def test_molmoact2_normalize_gripper_true_keeps_all_dims_normalized(tmp_path):