This commit is contained in:
Pepijn
2025-08-28 19:23:17 +02:00
parent bead25a58a
commit cc05067a76
6 changed files with 137 additions and 207 deletions
+25 -18
View File
@@ -25,10 +25,12 @@ from tests.utils import require_package
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_instantiation_and_forward_tensor_batch():
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False,
freeze_backbones=True,
)
@@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch():
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_instantiation_and_forward_list_batch_with_language():
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False,
freeze_backbones=True,
)
@@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language():
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_composite_loss_shapes_and_terms():
"""Smoke test composite loss: checks presence of terms and valid gradients."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
push_to_hub=False,
freeze_backbones=True,
loss_type="composite",
lambda_prog=1.0,
lambda_spatial_nce=0.5,
lambda_rewind=0.4,
num_ranking_pairs=32, # Fewer pairs for testing
last_k_for_nce=2,
use_video_rewind=True,
rewind_prob=0.5,
use_mismatch_loss=True,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
@@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms():
loss, logs = policy.forward(batch)
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
# Expect composite terms present with spatial awareness and ReWiND
assert "loss_prog" in logs
assert "loss_spatial_nce" in logs
assert "loss_rewind_forward" in logs
assert "loss_rewind_reverse" in logs
# Expect ReWiND loss terms (progress and mismatch)
assert "loss_progress" in logs
assert "loss_mismatch" in logs
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_preprocessor_tokenizes_and_copies_task():
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu",
push_to_hub=False,
)
@@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task():
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_preprocessor_string_task_and_to_batch():
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu",
push_to_hub=False,
)
@@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch():
@require_package("transformers")
@require_package("sentence_transformers")
def test_rlearn_pipeline_end_to_end_forward():
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
cfg = RLearNConfig(
model_name="google/siglip2-large-patch16-256",
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
device="cpu",
push_to_hub=False,
freeze_backbones=True,
loss_type="composite",
use_video_rewind=True,
)
cfg.input_features = {
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),