mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
dino v2
This commit is contained in:
@@ -25,10 +25,12 @@ from tests.utils import require_package
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||
"""Instantiate RLearN and run a forward pass with a (B, T, C, H, W) tensor input using a real model and real text."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
@@ -54,10 +56,12 @@ def test_rlearn_instantiation_and_forward_tensor_batch():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||
"""Instantiate RLearN and run a forward pass with a list-of-frames input and real language using a real model."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
)
|
||||
@@ -84,18 +88,17 @@ def test_rlearn_instantiation_and_forward_list_batch_with_language():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_composite_loss_shapes_and_terms():
|
||||
"""Smoke test composite loss: checks presence of terms and valid gradients."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
lambda_prog=1.0,
|
||||
lambda_spatial_nce=0.5,
|
||||
lambda_rewind=0.4,
|
||||
num_ranking_pairs=32, # Fewer pairs for testing
|
||||
last_k_for_nce=2,
|
||||
use_video_rewind=True,
|
||||
rewind_prob=0.5,
|
||||
use_mismatch_loss=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
@@ -117,17 +120,17 @@ def test_rlearn_composite_loss_shapes_and_terms():
|
||||
|
||||
loss, logs = policy.forward(batch)
|
||||
assert isinstance(loss, torch.Tensor) and torch.isfinite(loss)
|
||||
# Expect composite terms present with spatial awareness and ReWiND
|
||||
assert "loss_prog" in logs
|
||||
assert "loss_spatial_nce" in logs
|
||||
assert "loss_rewind_forward" in logs
|
||||
assert "loss_rewind_reverse" in logs
|
||||
# Expect ReWiND loss terms (progress and mismatch)
|
||||
assert "loss_progress" in logs
|
||||
assert "loss_mismatch" in logs
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
@@ -161,9 +164,11 @@ def test_rlearn_preprocessor_tokenizes_and_copies_task():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
)
|
||||
@@ -194,14 +199,16 @@ def test_rlearn_preprocessor_string_task_and_to_batch():
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_package("sentence_transformers")
|
||||
def test_rlearn_pipeline_end_to_end_forward():
|
||||
"""End-to-end: preprocessor + model forward using RLearN pipeline on synthetic data."""
|
||||
cfg = RLearNConfig(
|
||||
model_name="google/siglip2-large-patch16-256",
|
||||
vision_model_name="facebook/dinov3-vitb16-pretrain-lvd1689m",
|
||||
text_model_name="sentence-transformers/all-MiniLM-L12-v2",
|
||||
device="cpu",
|
||||
push_to_hub=False,
|
||||
freeze_backbones=True,
|
||||
loss_type="composite",
|
||||
use_video_rewind=True,
|
||||
)
|
||||
cfg.input_features = {
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
|
||||
Reference in New Issue
Block a user