fix failing test

This commit is contained in:
Jade Choghari
2025-11-25 16:01:39 +01:00
parent 829428ac81
commit f62cfc9ca2
3 changed files with 28 additions and 31 deletions
+1 -30
View File
@@ -121,19 +121,6 @@ class XVLAModel(nn.Module):
for param in self.transformer.soft_prompt_hub.parameters(): for param in self.transformer.soft_prompt_hub.parameters():
param.requires_grad = False param.requires_grad = False
def get_trainable_params_summary(self) -> dict[str, int]:
"""
Returns a summary of trainable vs frozen parameters.
"""
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
return {
"trainable": trainable,
"frozen": frozen,
"total": trainable + frozen,
"trainable_pct": 100.0 * trainable / (trainable + frozen) if (trainable + frozen) > 0 else 0.0,
}
def forward_vlm( def forward_vlm(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
@@ -248,17 +235,6 @@ class XVLAPolicy(PreTrainedPolicy):
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim) self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
self.reset() self.reset()
# Log trainable parameters summary
params_summary = self.model.get_trainable_params_summary()
print("XVLA Parameter Summary:")
print(f" Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)")
print(f" Frozen: {params_summary['frozen']:,}")
print(f" Total: {params_summary['total']:,}")
print(f" Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}")
print(f" Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}")
print(f" Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}")
print(f" Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}")
def reset(self) -> None: def reset(self) -> None:
self._queues = { self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps), ACTION: deque(maxlen=self.config.n_action_steps),
@@ -472,13 +448,8 @@ class XVLAPolicy(PreTrainedPolicy):
state_dict[shared_key] = state_dict[encoder_key] state_dict[shared_key] = state_dict[encoder_key]
# or deepcopy # or deepcopy
# step 4: load into instance # step 4: load into instance
missing, unexpected = instance.load_state_dict(state_dict, strict=True) instance.load_state_dict(state_dict, strict=True)
print("Loaded XVLA checkpoint") print("Loaded XVLA checkpoint")
if missing:
print(f"Missing keys: {missing}")
if unexpected:
print(f"Unexpected keys: {unexpected}")
# step 5: finalize # step 5: finalize
instance.to(config.device) instance.to(config.device)
instance.eval() instance.eval()
+3 -1
View File
@@ -260,7 +260,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.env is not None: if cfg.env is not None:
logging.info(f"{cfg.env.task=}") logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors") logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env) env_preprocessor, env_postprocessor = make_env_pre_post_processors(
env_cfg=cfg.env, policy_cfg=cfg.policy
)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}") logging.info(f"{dataset.num_episodes=}")
@@ -147,6 +147,30 @@ def create_dummy_data(device=DEVICE):
return batch return batch
# Pytest fixtures
@pytest.fixture(scope="module")
def xvla_components():
"""Fixture to instantiate and provide all XVLA components for tests."""
print(f"\nTesting with DEVICE='{DEVICE}'")
print("\n[Setup] Instantiating LeRobot XVLA policy...")
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
print("✔️ Model loaded successfully")
yield policy_obj, preprocessor_obj, postprocessor_obj
cleanup_memory()
@pytest.fixture(scope="module")
def policy(xvla_components):
"""Fixture to provide the XVLA policy for tests."""
return xvla_components[0]
@pytest.fixture(scope="module")
def preprocessor(xvla_components):
"""Fixture to provide the XVLA preprocessor for tests."""
return xvla_components[1]
def test_xvla_preprocessor_alignment(policy, preprocessor): def test_xvla_preprocessor_alignment(policy, preprocessor):
"""Test that LeRobot XVLA preprocessor produces expected outputs.""" """Test that LeRobot XVLA preprocessor produces expected outputs."""
print("\n" + "=" * 80) print("\n" + "=" * 80)