mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
f0d2b37beb
* chore(dependencies): upgrade transformers + hggingface-hub + peft + scipy * chore(dependencies): bump pi0 family to transformers v5 * chore(dependencies): bump wall x to transformers v5 * chore(dependencies): bump gr00t to transformers v5 * chore(style): fix pre-commit * fix(policy): xvla forced_bos_token missing * test(rl): skip ci tests for resnet10 * Fix: full pi models support for transformer v5 (#2967) * fix(pi): remove loss truncation * fix(pi): remove state padding before tokenization * fix(pi): fix image padding value * fix from_pretrain * add transformer v5 changes * remove reference * more fixes * make it work * add support for rest of pi family * add pifast work * more changes * more changes * more cleanup * fix torch params * dtype fix * torch compile * embed mismatch fix * revert groot * more nit fixes * remove unused classes * more fixes * revert * nit * torch dtype warning fix * but back dynamic renaming * add tie embedding --------- Co-authored-by: Yufei Sun <skieyfly@gmail.com> * chore: fix XVLA in transformers v5 (#3006) * test(policies): enable wall x CI testing * style(test): pre-commit check * style(test): pre-commit * fix wall x for transformer v5 (#3008) * tv5 fix * various wall x fixes * Delete tests/policies/pi0_pi05/print_pi05_output_logits.py Signed-off-by: Jade Choghari <chogharijade@gmail.com> * sync modeling_florence2.py with chore/bump_transformers_v5 * more * more fixes * more * remove comment * more --------- Signed-off-by: Jade Choghari <chogharijade@gmail.com> * chore(dependencies): adjust dependencies versioning after transformers v5 (#3034) * chore(dependecies): adjust dependecies versioning after transformers v5 * fix(policies): remove deprecated input_embeds * fix(policies): dict _tied_weights_keys * chore(depedencies): common qwen-vl-utils * chore(dependencies): bump transformers to 5.2 * Fix policy testing for tv5 (#3032) * fix ci logger * other fix * fix mypy * change logits to torch2.10 * skip wallx| * remove logging 
--------- Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> * feat(ci): log into HF to unblock some CI tests (#3007) * feat(ci): log into HF to unblock some CI tests * chore(ci): change hf call + secret name * fix(ci): temp fix for pi0 rtc test * test(policies): require_cuda for unblocked tests * test(policies): require_cuda wall_x * fic(tests): require_cuda outter most for pi0 * fix(test): return instead of yield --------- Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> * style(test): fix pre-commit * chore(deps): upgrade transformers (#3050) * chore(test): use lerobot model * fix(policies): change default action tokenizer for wall x * sample on cpu * Revert "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5" This reverts commitd9b76755f7, reversing changes made to89359cb0b6. * Reapply "Merge branch 'chore/bump_transformers_v5' of https://github.com/huggingface/lerobot into chore/bump_transformers_v5" This reverts commitc9914db78b. --------- Signed-off-by: Jade Choghari <chogharijade@gmail.com> Signed-off-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Jade Choghari <chogharijade@gmail.com> Co-authored-by: Yufei Sun <skieyfly@gmail.com> Co-authored-by: Pepijn <pepijn@huggingface.co>
381 lines
13 KiB
Python
381 lines
13 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# Module-wide marker: skip every test here when either the CI or the
# GITHUB_ACTIONS environment variable is set to the string "true".
pytestmark = pytest.mark.skipif(
    any(os.environ.get(flag) == "true" for flag in ("CI", "GITHUB_ACTIONS")),
    reason="TODO: This test seems to hang the CI",
)
|
|
|
|
|
|
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
|
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
|
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
|
from lerobot.utils.random_utils import set_seed # noqa: E402
|
|
from tests.utils import require_cuda # noqa: E402
|
|
|
|
|
|
@require_cuda
def test_pi0_rtc_initialization():
    """Check that a PI0 policy built with an enabled RTC config exposes an RTC processor."""
    set_seed(42)

    cfg = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")

    # RTC settings are attached to the config after it has been constructed.
    rtc_kwargs = {
        "enabled": True,
        "execution_horizon": 10,
        "max_guidance_weight": 5.0,
        "prefix_attention_schedule": RTCAttentionSchedule.EXP,
        "debug": False,
    }
    cfg.rtc_config = RTCConfig(**rtc_kwargs)

    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(14,))
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    cfg.input_features = {
        "observation.state": state_feature,
        "observation.images.base_0_rgb": image_feature,
    }
    cfg.output_features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}

    pi0 = PI0Policy(cfg)

    # The attribute must exist, be populated, and carry the enabled flag.
    assert hasattr(pi0, "rtc_processor")
    rtc = pi0.rtc_processor
    assert rtc is not None
    assert rtc.rtc_config.enabled is True

    print("✓ PI0 RTC initialization: Test passed")
|
|
|
|
|
|
@require_cuda
def test_pi0_rtc_initialization_without_rtc_config():
    """Check that a PI0 policy built without an RTC config has no RTC processor."""
    set_seed(42)

    pi0 = PI0Policy(PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32"))

    # The attribute still exists on the policy, but it (and the model's copy) is None.
    assert hasattr(pi0, "rtc_processor")
    assert pi0.rtc_processor is None
    assert pi0.model.rtc_processor is None

    # The policy's own helper must also report RTC as disabled.
    assert pi0._rtc_enabled() is False

    print("✓ PI0 RTC initialization without RTC config: Test passed")
|
|
|
|
|
|
@require_cuda
def test_pi0_rtc_inference_with_prev_chunk():
    """Run PI0 inference with RTC and a previous chunk, then compare against RTC disabled.

    Both predictions start from clones of the same noise tensor. The test checks
    the output shapes and that the two results are not close to each other.
    """
    set_seed(42)

    cfg = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # RTC settings are attached to the config after it has been constructed.
    cfg.rtc_config = RTCConfig(
        debug=False,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        max_guidance_weight=5.0,
        execution_horizon=10,
        enabled=True,
    )

    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(14,))
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    cfg.input_features = {
        "observation.state": state_feature,
        "observation.images.base_0_rgb": image_feature,
    }
    cfg.output_features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}

    # Zero-mean / unit-std statistics for every feature.
    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    pi0 = PI0Policy(cfg)
    pi0.eval()
    preprocess, _ = make_pi0_pre_post_processors(config=cfg, dataset_stats=stats)

    dev = cfg.device

    # Random tensors are drawn in a fixed order (state, image, previous chunk,
    # noise) so the seeded random stream is consumed the same way on every run.
    state = torch.randn(1, 14, dtype=torch.float32, device=dev)
    image = torch.rand(1, 3, 224, 224, dtype=torch.float32, device=dev)
    obs = preprocess(
        {
            "observation.state": state,
            "observation.images.base_0_rgb": image,
            "task": ["Pick up the object"],
        }
    )

    leftover = torch.randn(1, 25, 7, dtype=torch.float32, device=dev)
    out_shape = (1, cfg.chunk_size, 7)

    with torch.no_grad():
        # One noise sample, cloned for each call, so both runs start identically.
        base_noise = pi0.model.sample_noise(out_shape, dev)

        guided = pi0.predict_action_chunk(
            obs,
            noise=base_noise.clone(),
            prev_chunk_left_over=leftover,
            inference_delay=4,
            execution_horizon=10,
        )

        # Temporarily switch RTC off for the baseline, then restore the flag.
        pi0.config.rtc_config.enabled = False
        unguided = pi0.predict_action_chunk(obs, noise=base_noise.clone())
        pi0.config.rtc_config.enabled = True

    assert guided.shape == out_shape
    assert unguided.shape == out_shape

    # With a previous chunk supplied, the RTC result must differ from the baseline.
    assert not torch.allclose(guided, unguided, rtol=1e-3)

    print("✓ PI0 RTC inference with prev_chunk: Test passed")
|
|
|
|
|
|
@require_cuda
def test_pi0_rtc_inference_without_prev_chunk():
    """Run PI0 inference with RTC enabled but no previous chunk.

    The prediction is compared with one made while RTC is disabled, both starting
    from clones of the same noise; the two results are expected to be close.
    """
    set_seed(42)

    cfg = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # RTC settings are attached to the config after it has been constructed.
    cfg.rtc_config = RTCConfig(
        debug=False,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        max_guidance_weight=5.0,
        execution_horizon=10,
        enabled=True,
    )

    state_feature = PolicyFeature(type=FeatureType.STATE, shape=(14,))
    image_feature = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
    cfg.input_features = {
        "observation.state": state_feature,
        "observation.images.base_0_rgb": image_feature,
    }
    cfg.output_features = {"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,))}

    # Zero-mean / unit-std statistics for every feature.
    stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    pi0 = PI0Policy(cfg)
    pi0.eval()
    preprocess, _ = make_pi0_pre_post_processors(config=cfg, dataset_stats=stats)

    dev = cfg.device

    # Random tensors are drawn in a fixed order (state, image, noise) so the
    # seeded random stream is consumed the same way on every run.
    state = torch.randn(1, 14, dtype=torch.float32, device=dev)
    image = torch.rand(1, 3, 224, 224, dtype=torch.float32, device=dev)
    obs = preprocess(
        {
            "observation.state": state,
            "observation.images.base_0_rgb": image,
            "task": ["Pick up the object"],
        }
    )

    with torch.no_grad():
        # One noise sample, cloned for each call, so both runs start identically.
        base_noise = pi0.model.sample_noise((1, cfg.chunk_size, 7), dev)

        # RTC is enabled, but no leftover chunk is passed in.
        rtc_no_prev = pi0.predict_action_chunk(
            obs,
            noise=base_noise.clone(),
            prev_chunk_left_over=None,
        )

        # Temporarily switch RTC off for the baseline, then restore the flag.
        pi0.config.rtc_config.enabled = False
        baseline = pi0.predict_action_chunk(obs, noise=base_noise.clone())
        pi0.config.rtc_config.enabled = True

    # Without a previous chunk, the RTC-enabled result must match the baseline.
    assert torch.allclose(rtc_no_prev, baseline, rtol=1e-5)

    print("✓ PI0 RTC inference without prev_chunk: Test passed")
|
|
|
|
|
|
@require_cuda
def test_pi0_rtc_validation_rules():
    """Test PI0 policy with RTC against a non-RTC baseline using the same noise.

    A prediction is made with RTC enabled, a previous chunk, an inference delay
    and an execution horizon, and another with RTC disabled. The test asserts
    that the two results are not close to each other.

    NOTE(review): the test name mentions "validation rules", but the only
    property asserted here is that the RTC output differs from the baseline.
    """
    set_seed(42)

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    # Add RTC config
    config.rtc_config = RTCConfig(
        enabled=True,
        execution_horizon=10,
        max_guidance_weight=5.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
        debug=False,
    )

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats (zero mean, unit std for every feature)
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    # Instantiate policy and preprocessor
    policy = PI0Policy(config)
    policy.eval()
    preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

    device = config.device

    # Create dummy batch
    batch = {
        "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
        "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
        "task": ["Pick up the object"],
    }
    batch = preprocessor(batch)

    # Create previous chunk
    prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

    inference_delay = 4
    execution_horizon = 10

    with torch.no_grad():
        # Use same noise for fair comparison
        noise = policy.model.sample_noise((1, config.chunk_size, 7), device)

        # Test with RTC
        actions_with_rtc = policy.predict_action_chunk(
            batch,
            noise=noise.clone(),
            prev_chunk_left_over=prev_chunk,
            inference_delay=inference_delay,
            execution_horizon=execution_horizon,
        )

        # Test without RTC: disable the flag for the baseline, then restore it
        policy.config.rtc_config.enabled = False
        actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
        policy.config.rtc_config.enabled = True

    assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)


@require_cuda
def test_pi0_rtc_different_schedules():
    """Test PI0 with different RTC attention schedules.

    For each of the ZEROS, ONES, LINEAR and EXP schedules a fresh policy is
    built and one RTC-guided prediction is made; only the output shape is
    asserted.
    """
    # This body previously sat inside test_pi0_rtc_validation_rules behind a bare
    # string literal; it is its own test so the two report and fail independently.
    set_seed(42)

    schedules = [
        RTCAttentionSchedule.ZEROS,
        RTCAttentionSchedule.ONES,
        RTCAttentionSchedule.LINEAR,
        RTCAttentionSchedule.EXP,
    ]

    config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")

    config.input_features = {
        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
        "observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
    }
    config.output_features = {
        "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
    }

    # Create dataset stats (zero mean, unit std for every feature)
    dataset_stats = {
        "observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
        "action": {"mean": torch.zeros(7), "std": torch.ones(7)},
        "observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
    }

    device = config.device

    for schedule in schedules:
        print(f"Testing schedule: {schedule}")

        # Add RTC config with specific schedule
        config.rtc_config = RTCConfig(
            enabled=True,
            execution_horizon=10,
            max_guidance_weight=5.0,
            prefix_attention_schedule=schedule,
            debug=False,
        )

        # Instantiate a fresh policy and preprocessor for this schedule
        policy = PI0Policy(config)
        policy.eval()
        preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)

        # Create dummy batch
        batch = {
            "observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
            "observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
            "task": ["Pick up the object"],
        }
        batch = preprocessor(batch)

        # Create previous chunk
        prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)

        with torch.no_grad():
            noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
            actions = policy.predict_action_chunk(
                batch,
                noise=noise,
                prev_chunk_left_over=prev_chunk,
                inference_delay=4,
                execution_horizon=10,
            )

        # Verify shape
        assert actions.shape == (1, config.chunk_size, 7)
        print(f" ✓ Schedule {schedule}: Test passed")

    print("✓ PI0 RTC different schedules: All schedules tested")
|