preparing for training adding some temporary debug code aswell to visualize model output

This commit is contained in:
Maxime Ellerbach
2026-06-12 15:25:28 +00:00
parent 7c063c3fbc
commit a323ea67b6
6 changed files with 282 additions and 52 deletions
+12 -6
View File
@@ -49,6 +49,8 @@ def test_fastwam_is_registered_and_publicly_exported():
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
base_model_id=None,
)
@@ -78,6 +80,8 @@ def test_preprocessor_normalizes_images_and_postprocessor_toggles_actions(tmp_pa
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(2, 2),
device="cpu",
toggle_action_dimensions=[-1],
@@ -154,6 +158,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
proprio_dim=2,
action_horizon=4,
n_action_steps=2,
num_video_frames=5,
action_video_freq_ratio=1,
image_size=(16, 16),
input_features={
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 16, 16)),
@@ -164,7 +170,7 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
)
policy = FastWAMPolicy(cfg)
output = policy.forward(
loss, metrics = policy.forward(
{
"observation.images.image": torch.zeros(1, 3, 16, 16),
OBS_STATE: torch.zeros(1, 2),
@@ -186,8 +192,8 @@ def test_policy_forward_and_predict_action_adapt_lerobot_batches(monkeypatch):
}
)
assert output["loss"].item() == 1.0
assert output["loss_action"].item() == 1.0
assert loss.item() == 1.0
assert metrics["loss_action"] == 1.0
assert action.shape == (2, 4, 3)
assert action[:, 0, 0].tolist() == [1.0, 2.0]
assert [item["image_shape"] for item in captured] == [(1, 3, 16, 16), (1, 3, 16, 16)]
@@ -218,7 +224,7 @@ class CoreWithFrozenComponents(FakeFastWAMCore):
def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
def build_core(self, config):
core = CoreWithFrozenComponents()
@@ -250,7 +256,7 @@ def test_from_pretrained_uses_base_loader_and_skips_wan_backbone(monkeypatch, tm
def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)
@@ -272,7 +278,7 @@ def test_save_pretrained_excludes_frozen_components(monkeypatch, tmp_path):
def test_frozen_components_excluded_from_params_but_follow_device_moves(monkeypatch):
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, base_model_id=None)
cfg = FastWAMConfig(action_dim=3, proprio_dim=2, action_horizon=4, n_action_steps=2, num_video_frames=5, action_video_freq_ratio=1, base_model_id=None)
monkeypatch.setattr(FastWAMPolicy, "_build_core_model", lambda self, config: CoreWithFrozenComponents())
policy = FastWAMPolicy(cfg)