mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors (#3128)
* fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors
When VQ discretization phase completes, the code was overwriting
register_buffer('discretized') and register_buffer('freeze_codebook')
with torch.tensor(True), which is created on CPU. DDP then fails in
_sync_buffers() with: RuntimeError: No backend type associated with
device type cpu. Fix by updating the buffers in-place with .fill_(True)
so device and registration are preserved.
Made-with: Cursor
* test(vqbet): add regression test for in-place buffer update during discretization
Verifies that discretize() updates the 'discretized' and 'freeze_codebook'
registered buffers in-place (via fill_()) rather than replacing them with new
CPU tensors. The test checks data_ptr() identity and that the tensors remain
registered buffers after the call. This prevents regressions of the DDP fix.
Made-with: Cursor
* test(vqbet): add GPU regression test to verify buffers stay on CUDA after discretize()
Directly catches the original DDP failure mode: when buffers are replaced with
torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type
associated with device type cpu' in _sync_buffers(). The GPU test places the
model on cuda:0 and asserts both buffers remain on CUDA after discretization.
Made-with: Cursor
* test(vqbet): simplify to single device-check test in test_policies.py
Per reviewer feedback: remove the separate test file and replace the two
CPU/GPU tests (with data_ptr checks) with a single focused test in
tests/policies/test_policies.py that only asserts the registered buffers
remain on the model device after discretize(). Uses DEVICE from tests/utils.py
so it runs on whatever device the CI/user selects (cpu, cuda, mps).
Made-with: Cursor
* style: fix import order in test_policies.py to pass ruff/pre-commit checks
Made-with: Cursor
---------
Co-authored-by: Zhan DiJia <2476100824@example.com>
Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>
This commit is contained in:
@@ -467,8 +467,8 @@ class VQBeTHead(nn.Module):
|
||||
self.vqvae_model.optimized_steps += 1
|
||||
# if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part.
|
||||
if self.vqvae_model.optimized_steps >= n_vqvae_training_steps:
|
||||
self.vqvae_model.discretized = torch.tensor(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
|
||||
self.vqvae_model.discretized.fill_(True)
|
||||
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
|
||||
print("Finished discretizing action data!")
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
|
||||
@@ -42,6 +42,8 @@ from lerobot.policies.factory import (
|
||||
make_pre_post_processors,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTHead
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
@@ -460,3 +462,45 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and accepting 0.01% error.
|
||||
torch.testing.assert_close(online_avg, offline_avg, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_vqbet_discretize_keeps_buffers_on_device():
    """Regression test: VQBeTHead.discretize() must leave registered buffers on the model device.

    The old code replaced the registered buffer via
    `self.vqvae_model.discretized = torch.tensor(True)`, which creates a fresh CPU
    tensor and makes DDP fail in `_sync_buffers()` with:
        RuntimeError: No backend type associated with device type cpu
    The fix updates the buffers in place with `.fill_(True)`, so device placement
    (and buffer registration) is preserved.
    """
    cfg = VQBeTConfig()
    cfg.input_features = {
        OBS_IMAGES: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 96, 96)),
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(6,)),
    }
    cfg.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
    }
    # Keep every dimension tiny so the test is fast on cpu/cuda/mps alike.
    cfg.n_vqvae_training_steps = 3
    cfg.vqvae_n_embed = 8
    cfg.vqvae_embedding_dim = 32
    cfg.vqvae_enc_hidden_dim = 32
    cfg.action_chunk_size = 2
    cfg.crop_shape = (84, 84)

    head = VQBeTHead(cfg).to(DEVICE)

    actions = torch.randn(4, cfg.action_chunk_size, cfg.action_feature.shape[0], device=DEVICE)
    # Run exactly enough steps to trigger the "finished discretizing" branch.
    for _ in range(cfg.n_vqvae_training_steps):
        head.discretize(cfg.n_vqvae_training_steps, actions)

    expected_device_type = torch.device(DEVICE).type
    vqvae = head.vqvae_model
    assert vqvae.discretized.device.type == expected_device_type, (
        "vqvae_model.discretized was moved off the model device after discretize(). "
        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
    )
    assert vqvae.vq_layer.freeze_codebook.device.type == expected_device_type, (
        "vq_layer.freeze_codebook was moved off the model device after discretize(). "
        "Use .fill_(True) instead of = torch.tensor(True) to keep the buffer on device."
    )
||||
Reference in New Issue
Block a user