mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
51 lines
2.0 KiB
Python
51 lines
2.0 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Tests for training-time RTC helpers."""
|
|
|
|
import torch
|
|
|
|
from lerobot.configs.types import RTCTrainingDelayDistribution
|
|
from lerobot.policies.rtc.configuration_rtc import RTCTrainingConfig
|
|
from lerobot.policies.rtc.training_time import apply_rtc_training_time, sample_rtc_delay
|
|
|
|
|
|
def test_rtc_training_config_defaults():
    """A default-constructed ``RTCTrainingConfig`` is disabled with zero delays.

    Also pins the default delay distribution (uniform) and the default
    ``exp_decay`` value (1.0).
    """
    defaults = RTCTrainingConfig()

    # ``enabled`` is checked by identity so a merely falsy value does not pass.
    assert defaults.enabled is False

    expected_by_field = {
        "min_delay": 0,
        "max_delay": 0,
        "delay_distribution": RTCTrainingDelayDistribution.UNIFORM,
        "exp_decay": 1.0,
    }
    for field_name, expected_value in expected_by_field.items():
        assert getattr(defaults, field_name) == expected_value
|
|
|
|
|
|
def test_sample_rtc_delay_uniform_range():
    """Delays sampled for a batch stay within the configured min/max bounds."""
    lower_bound, upper_bound = 1, 4
    rtc_cfg = RTCTrainingConfig(enabled=True, min_delay=lower_bound, max_delay=upper_bound)

    sampled = sample_rtc_delay(rtc_cfg, batch_size=100, device=torch.device("cpu"))

    # Both bounds are treated as inclusive by these checks.
    smallest = sampled.min().item()
    assert smallest >= lower_bound
    largest = sampled.max().item()
    assert largest <= upper_bound
|
|
|
|
|
|
def test_apply_rtc_training_time_prefix_mask():
    """Check per-step time tokens and the postfix mask for a single sample.

    Uses one sample with time 0.5, delay 2 and a sequence length of 4.
    """
    base_time = torch.tensor([0.5])
    delay = torch.tensor([2])

    per_step_time, is_postfix = apply_rtc_training_time(base_time, delay, seq_len=4)

    expected_shape = (1, 4)
    assert per_step_time.shape == expected_shape
    assert is_postfix.shape == expected_shape

    # Delay=2 means the first two steps are prefix (time forced to 0.0) and only the last two are postfix.
    expected_time = torch.tensor([0.0, 0.0, 0.5, 0.5])
    assert torch.allclose(per_step_time[0], expected_time)

    expected_mask = torch.tensor([False, False, True, True])
    assert torch.equal(is_postfix[0], expected_mask)
|