Files
lerobot/tests/policies/rtc/test_training_time_rtc.py
T
2026-01-20 20:08:28 +01:00

51 lines
2.0 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for training-time RTC helpers."""
import torch
from lerobot.configs.types import RTCTrainingDelayDistribution
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
def test_rtc_training_config_defaults():
    """Check the field values of a default-constructed ``RTCTrainingConfig``."""
    defaults = RTCTrainingConfig()

    # Identity check on purpose: the flag must be the bool ``False``, not merely falsy.
    assert defaults.enabled is False

    # Remaining fields are compared by equality against the pinned defaults.
    expected_fields = (
        ("min_delay", 0),
        ("max_delay", 0),
        ("delay_distribution", RTCTrainingDelayDistribution.UNIFORM),
        ("exp_decay", 1.0),
    )
    for field_name, expected in expected_fields:
        assert getattr(defaults, field_name) == expected
def test_sample_rtc_delay_uniform_range():
    """Check that sampled delays never fall outside the configured ``[min_delay, max_delay]`` bounds."""
    lower_bound, upper_bound = 1, 4
    rtc_cfg = RTCTrainingConfig(enabled=True, min_delay=lower_bound, max_delay=upper_bound)

    sampled = sample_rtc_delay(rtc_cfg, batch_size=100, device=torch.device("cpu"))

    # Only the extremes need checking: if min and max are in range, every sample is.
    smallest = sampled.min().item()
    assert smallest >= lower_bound
    largest = sampled.max().item()
    assert largest <= upper_bound
def test_apply_rtc_training_time_prefix_mask():
    """Check per-step time tokens and the postfix mask for a single sample with delay 2 and 4 steps."""
    base_time = torch.tensor([0.5])
    delay_steps = torch.tensor([2])

    tokens, mask = apply_rtc_training_time(base_time, delay_steps, seq_len=4)

    # One sample, four sequence positions.
    expected_shape = (1, 4)
    assert tokens.shape == expected_shape
    assert mask.shape == expected_shape

    # With delay=2 the first two steps form the prefix (time forced to 0.0);
    # only the last two steps are postfix and keep the sampled time of 0.5.
    expected_tokens = torch.tensor([0.0, 0.0, 0.5, 0.5])
    expected_mask = torch.tensor([False, False, True, True])
    assert torch.allclose(tokens[0], expected_tokens)
    assert torch.equal(mask[0], expected_mask)