mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix failing test
This commit is contained in:
@@ -121,19 +121,6 @@ class XVLAModel(nn.Module):
|
||||
for param in self.transformer.soft_prompt_hub.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_trainable_params_summary(self) -> dict[str, int]:
|
||||
"""
|
||||
Returns a summary of trainable vs frozen parameters.
|
||||
"""
|
||||
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
|
||||
return {
|
||||
"trainable": trainable,
|
||||
"frozen": frozen,
|
||||
"total": trainable + frozen,
|
||||
"trainable_pct": 100.0 * trainable / (trainable + frozen) if (trainable + frozen) > 0 else 0.0,
|
||||
}
|
||||
|
||||
def forward_vlm(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@@ -248,17 +235,6 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||
self.reset()
|
||||
|
||||
# Log trainable parameters summary
|
||||
params_summary = self.model.get_trainable_params_summary()
|
||||
print("XVLA Parameter Summary:")
|
||||
print(f" Trainable: {params_summary['trainable']:,} ({params_summary['trainable_pct']:.2f}%)")
|
||||
print(f" Frozen: {params_summary['frozen']:,}")
|
||||
print(f" Total: {params_summary['total']:,}")
|
||||
print(f" Vision Encoder: {'Frozen' if config.freeze_vision_encoder else 'Trainable'}")
|
||||
print(f" Language Encoder: {'Frozen' if config.freeze_language_encoder else 'Trainable'}")
|
||||
print(f" Policy Transformer: {'Trainable' if config.train_policy_transformer else 'Frozen'}")
|
||||
print(f" Soft Prompts: {'Trainable' if config.train_soft_prompts else 'Frozen'}")
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
@@ -472,13 +448,8 @@ class XVLAPolicy(PreTrainedPolicy):
|
||||
state_dict[shared_key] = state_dict[encoder_key]
|
||||
# or deepcopy
|
||||
# step 4: load into instance
|
||||
missing, unexpected = instance.load_state_dict(state_dict, strict=True)
|
||||
instance.load_state_dict(state_dict, strict=True)
|
||||
print("Loaded XVLA checkpoint")
|
||||
if missing:
|
||||
print(f"Missing keys: {missing}")
|
||||
if unexpected:
|
||||
print(f"Unexpected keys: {unexpected}")
|
||||
|
||||
# step 5: finalize
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
|
||||
@@ -260,7 +260,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info("Creating environment processors")
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
|
||||
env_cfg=cfg.env, policy_cfg=cfg.policy
|
||||
)
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
|
||||
@@ -147,6 +147,30 @@ def create_dummy_data(device=DEVICE):
|
||||
return batch
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture(scope="module")
|
||||
def xvla_components():
|
||||
"""Fixture to instantiate and provide all XVLA components for tests."""
|
||||
print(f"\nTesting with DEVICE='{DEVICE}'")
|
||||
print("\n[Setup] Instantiating LeRobot XVLA policy...")
|
||||
policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_xvla(from_pretrained=True)
|
||||
print("✔️ Model loaded successfully")
|
||||
yield policy_obj, preprocessor_obj, postprocessor_obj
|
||||
cleanup_memory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def policy(xvla_components):
|
||||
"""Fixture to provide the XVLA policy for tests."""
|
||||
return xvla_components[0]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def preprocessor(xvla_components):
|
||||
"""Fixture to provide the XVLA preprocessor for tests."""
|
||||
return xvla_components[1]
|
||||
|
||||
|
||||
def test_xvla_preprocessor_alignment(policy, preprocessor):
|
||||
"""Test that LeRobot XVLA preprocessor produces expected outputs."""
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
Reference in New Issue
Block a user