From f62cfc9ca23ad6c27cf8be6b3885bc9aea966012 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 25 Nov 2025 16:01:39 +0100 Subject: [PATCH] fix failing test --- src/lerobot/policies/xvla/modeling_xvla.py | 31 +------------------ src/lerobot/scripts/lerobot_train.py | 4 ++- .../xvla/test_xvla_original_vs_lerobot.py | 24 ++++++++++++++ 3 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/lerobot/policies/xvla/modeling_xvla.py b/src/lerobot/policies/xvla/modeling_xvla.py index cb1784e8f..8da58264f 100644 --- a/src/lerobot/policies/xvla/modeling_xvla.py +++ b/src/lerobot/policies/xvla/modeling_xvla.py @@ -121,19 +121,6 @@ class XVLAModel(nn.Module): for param in self.transformer.soft_prompt_hub.parameters(): param.requires_grad = False - def get_trainable_params_summary(self) -> dict[str, int]: - """ - Returns a summary of trainable vs frozen parameters. - """ - trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) - frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad) - return { - "trainable": trainable, - "frozen": frozen, - "total": trainable + frozen, - "trainable_pct": 100.0 * trainable / (trainable + frozen) if (trainable + frozen) > 0 else 0.0, - } - def forward_vlm( self, input_ids: torch.LongTensor, @@ -248,17 +235,6 @@ class XVLAPolicy(PreTrainedPolicy): self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim) self.reset() - # Log trainable parameters summary - params_summary = self.model.get_trainable_params_summary() - print("XVLA Parameter Summary:") - print(f" Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)") - print(f" Frozen: {params_summary['frozen']:,}") - print(f" Total: {params_summary['total']:,}") - print(f" Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}") - print(f" Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}") - print(f" Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}") - print(f" Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}") - def reset(self) -> None: self._queues = { ACTION: deque(maxlen=self.config.n_action_steps), @@ -472,13 +448,8 @@ class XVLAPolicy(PreTrainedPolicy): state_dict[shared_key] = state_dict[encoder_key] # or deepcopy # step 4: load into instance - missing, unexpected = instance.load_state_dict(state_dict, strict=True) + instance.load_state_dict(state_dict, strict=True) print("Loaded XVLA checkpoint") - if missing: - print(f"Missing keys: {missing}") - if unexpected: - print(f"Unexpected keys: {unexpected}") - # step 5: finalize instance.to(config.device) instance.eval() diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index fd200a254..92fd1917d 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -260,7 +260,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.env is not None: logging.info(f"{cfg.env.task=}") logging.info("Creating environment processors") - env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) + env_preprocessor, env_postprocessor = make_env_pre_post_processors( + env_cfg=cfg.env, policy_cfg=cfg.policy + ) logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info(f"{dataset.num_episodes=}") diff --git a/tests/policies/xvla/test_xvla_original_vs_lerobot.py b/tests/policies/xvla/test_xvla_original_vs_lerobot.py index 9e803b41e..9baaa85cd 100644 --- a/tests/policies/xvla/test_xvla_original_vs_lerobot.py +++ b/tests/policies/xvla/test_xvla_original_vs_lerobot.py @@ -147,6 +147,30 @@ def create_dummy_data(device=DEVICE): return batch +# Pytest fixtures +@pytest.fixture(scope="module") +def xvla_components(): + """Fixture to instantiate and provide all XVLA components for tests.""" + print(f"\nTesting with DEVICE='{DEVICE}'") + print("\n[Setup] Instantiating LeRobot XVLA policy...") + policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True) + print("✔️ Model loaded successfully") + yield policy_obj, preprocessor_obj, postprocessor_obj + cleanup_memory() + + +@pytest.fixture(scope="module") +def policy(xvla_components): + """Fixture to provide the XVLA policy for tests.""" + return xvla_components[0] + + +@pytest.fixture(scope="module") +def preprocessor(xvla_components): + """Fixture to provide the XVLA preprocessor for tests.""" + return xvla_components[1] + + def test_xvla_preprocessor_alignment(policy, preprocessor): """Test that LeRobot XVLA preprocessor produces expected outputs.""" print("\n" + "=" * 80)