Files
lerobot/tests/rl/test_trainer.py
T
Khalil Meftah 8d50be9faa refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring
- Add RLAlgorithm base class and RLAlgorithmConfig with draccus.ChoiceRegistry
- Add RLTrainer for unified training orchestration with iterator pattern
- Add DataMixer and OnlineOfflineMixer for online/offline data mixing
- Restructure SAC algorithm with batch iterator and factory pattern
- Add observation normalization pre/post processors
- Add comprehensive tests for all new components
2026-03-03 16:50:00 +01:00

116 lines
3.6 KiB
Python

#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats
from lerobot.rl.trainer import RLTrainer
from lerobot.utils.constants import ACTION, OBS_STATE
class _CountingAlgorithm(RLAlgorithm):
    """Stub ``RLAlgorithm`` that records how often the trainer calls into it.

    The only state it keeps is two counters, ``configure_calls`` and
    ``update_calls``; every other method returns a trivial placeholder.
    """

    def __init__(self):
        # Counters inspected by the tests below.
        self.configure_calls = self.update_calls = 0

    def select_action(self, observation: dict[str, Tensor]) -> Tensor:
        """Return a fixed one-element zero tensor; ``observation`` is not read."""
        return torch.zeros(1)

    def configure_data_iterator(
        self,
        data_mixer,
        batch_size: int,
        *,
        async_prefetch: bool = True,
        queue_size: int = 2,
    ):
        """Count the call, then return ``data_mixer.get_iterator(...)`` with the same options."""
        self.configure_calls += 1
        iterator_options = {
            "batch_size": batch_size,
            "async_prefetch": async_prefetch,
            "queue_size": queue_size,
        }
        return data_mixer.get_iterator(**iterator_options)

    def make_optimizers(self):
        """Return an empty dict."""
        return dict()

    def update(self, batch_iterator):
        """Count the call, pull one batch from ``batch_iterator``, and return dummy stats."""
        self.update_calls += 1
        # The batch content is irrelevant here; it is fetched and discarded.
        next(batch_iterator)
        return TrainingStats(losses={"dummy": 1.0})

    def load_weights(self, weights, device="cpu") -> None:
        """Accept and ignore ``weights`` and ``device``."""
        del weights, device
class _SimpleMixer:
    """Stand-in data mixer whose iterator yields random batches without end."""

    def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
        """Yield transition-style batch dicts forever.

        Each batch holds ``"state"``/``"next_state"`` dicts keyed by ``OBS_STATE``
        with ``(batch_size, 3)`` random tensors, an ``ACTION`` tensor of shape
        ``(batch_size, 2)``, a random ``"reward"`` of length ``batch_size``, zero
        ``"done"``/``"truncated"`` tensors, and ``"complementary_info"`` set to
        ``None``. ``async_prefetch`` and ``queue_size`` are accepted but unused.
        """
        del async_prefetch, queue_size
        while True:
            # Random draws happen in the same order as the keys appear in the batch.
            state = {OBS_STATE: torch.randn(batch_size, 3)}
            action = torch.randn(batch_size, 2)
            reward = torch.randn(batch_size)
            next_state = {OBS_STATE: torch.randn(batch_size, 3)}
            yield {
                "state": state,
                ACTION: action,
                "reward": reward,
                "next_state": next_state,
                "done": torch.zeros(batch_size),
                "truncated": torch.zeros(batch_size),
                "complementary_info": None,
            }
def test_trainer_lazy_iterator_lifecycle_and_reset():
    """Check when the trainer asks the algorithm to configure its data iterator.

    Expected: once on the first ``training_step``, not again on the second,
    and once more on the first step after ``reset_data_iterator``.
    """
    counting_algo = _CountingAlgorithm()
    trainer = RLTrainer(
        algorithm=counting_algo,
        data_mixer=_SimpleMixer(),
        batch_size=4,
        async_prefetch=False,
    )

    def call_counts():
        # (configure_calls, update_calls) as seen by the stub algorithm.
        return counting_algo.configure_calls, counting_algo.update_calls

    # Step 1: the iterator is built for the first time.
    trainer.training_step()
    assert call_counts() == (1, 1)

    # Step 2: the existing iterator is reused, so no new configure call.
    trainer.training_step()
    assert call_counts() == (1, 2)

    # After an explicit reset, the next step rebuilds the iterator.
    trainer.reset_data_iterator()
    trainer.training_step()
    assert call_counts() == (2, 3)
def test_trainer_set_data_mixer_resets_by_default():
    """Check that swapping the mixer with ``reset=True`` causes a new configure call.

    NOTE(review): the test name says "by default", yet ``reset=True`` is passed
    explicitly, so the default value of ``reset`` is not exercised here --
    confirm whether the argument should be omitted.
    """
    counting_algo = _CountingAlgorithm()
    first_mixer, second_mixer = _SimpleMixer(), _SimpleMixer()
    trainer = RLTrainer(
        algorithm=counting_algo,
        data_mixer=first_mixer,
        batch_size=2,
        async_prefetch=False,
    )

    # The initial step configures the iterator against the first mixer.
    trainer.training_step()
    assert counting_algo.configure_calls == 1

    # Replacing the mixer with a reset must lead to a second configure call.
    trainer.set_data_mixer(second_mixer, reset=True)
    trainer.training_step()
    assert counting_algo.configure_calls == 2
def test_algorithm_optimization_step_contract_defaults():
    """Check that ``optimization_step`` reads 0 initially and reflects a later assignment."""
    stub = _CountingAlgorithm()
    initial_step = stub.optimization_step
    assert initial_step == 0

    stub.optimization_step = 11
    updated_step = stub.optimization_step
    assert updated_step == 11