refactor: simplify docstrings for clarity and conciseness across multiple files

This commit is contained in:
Khalil Meftah
2026-04-28 11:11:02 +02:00
parent e298474bf3
commit ef6b3b5b0f
7 changed files with 7 additions and 53 deletions
@@ -36,15 +36,6 @@ DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
class GaussianActorPolicy( class GaussianActorPolicy(
PreTrainedPolicy, PreTrainedPolicy,
): ):
"""Gaussian actor + observation encoder.
Policy-side ``nn.Module`` used by SAC and related maximum-entropy continuous
control algorithms. It owns the actor network (``Policy``) and the observation
encoder (``GaussianActorObservationEncoder``); the critics, temperature, and
Bellman-update logic live on the algorithm side
(see ``lerobot.rl.algorithms.sac``).
"""
config_class = GaussianActorConfig config_class = GaussianActorConfig
name = "gaussian_actor" name = "gaussian_actor"
+1 -8
View File
@@ -291,11 +291,7 @@ def act_with_policy(
with policy_timer: with policy_timer:
normalized_observation = preprocessor.process_observation(observation) normalized_observation = preprocessor.process_observation(observation)
action = policy.select_action(batch=normalized_observation) action = policy.select_action(batch=normalized_observation)
# Unnormalize only the continuous part. When `num_discrete_actions` is set, # Unnormalize only the continuous part.
# `select_action` concatenates an argmax index in env space at the last dim;
# action stats cover the continuous dims only, so feeding the full vector to
# the unnormalizer would shape-mismatch and would also corrupt the discrete
# index by treating it as a normalized value.
if cfg.policy.num_discrete_actions is not None: if cfg.policy.num_discrete_actions is not None:
continuous_action = postprocessor.process_action(action[..., :-1]) continuous_action = postprocessor.process_action(action[..., :-1])
discrete_action = action[..., -1:].to( discrete_action = action[..., -1:].to(
@@ -346,9 +342,6 @@ def act_with_policy(
"discrete_penalty": torch.tensor( "discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)] [new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
), ),
# Forward the intervention flag so the learner can route this transition
# into the offline replay buffer (see `process_transitions` in learner.py).
# Use the plain string key so the payload survives torch.load(weights_only=True).
TeleopEvents.IS_INTERVENTION.value: is_intervention, TeleopEvents.IS_INTERVENTION.value: is_intervention,
} }
# Create transition for learner (convert to old format) # Create transition for learner (convert to old format)
+1 -1
View File
@@ -21,7 +21,7 @@ from lerobot.rl.algorithms.configs import RLAlgorithmConfig
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig: def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
"""Instantiate an :class:`RLAlgorithmConfig` from its registered type name. """Instantiate an `RLAlgorithmConfig` from its registered type name.
Args: Args:
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``). algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
+3 -19
View File
@@ -21,11 +21,7 @@ from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
class DataMixer(abc.ABC): class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies. """Abstract interface for all data mixing strategies."""
Subclasses must implement ``sample(batch_size)`` and may override
``get_iterator`` for specialised iteration.
"""
@abc.abstractmethod @abc.abstractmethod
def sample(self, batch_size: int) -> BatchType: def sample(self, batch_size: int) -> BatchType:
@@ -38,25 +34,13 @@ class DataMixer(abc.ABC):
async_prefetch: bool = True, async_prefetch: bool = True,
queue_size: int = 2, queue_size: int = 2,
): ):
"""Infinite iterator that yields batches. """Infinite iterator that yields batches."""
The default implementation repeatedly calls ``self.sample()``.
Subclasses with underlying buffer iterators (async prefetch)
should override this for better throughput.
"""
while True: while True:
yield self.sample(batch_size) yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer): class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an optional offline replay buffer. """Mixes transitions from an online and an offline replay buffer."""
When both buffers are present, each batch is constructed by sampling
``ceil(batch_size * online_ratio)`` from the online buffer and the
remainder from the offline buffer, then concatenating.
This mixer assumes both online and offline buffers are present.
"""
def __init__( def __init__(
self, self,
+1 -6
View File
@@ -31,10 +31,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
# TODO: Make `TrainPipelineConfig.dataset` optional # TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made its type non-optional dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made its type non-optional
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`, # Algorithm config.
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
# default hyperparameters. The top-level `policy` is injected into
# ``algorithm.policy_config`` at validation time.
algorithm: RLAlgorithmConfig | None = None algorithm: RLAlgorithmConfig | None = None
# Data mixer strategy name. Currently supports "online_offline". # Data mixer strategy name. Currently supports "online_offline".
@@ -48,7 +45,5 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
if self.algorithm is None: if self.algorithm is None:
self.algorithm = make_algorithm_config("sac") self.algorithm = make_algorithm_config("sac")
# The pipeline owns the policy config; inject it so the algorithm can
# introspect policy architecture (e.g. ``num_discrete_actions``).
if getattr(self.algorithm, "policy_config", None) is None: if getattr(self.algorithm, "policy_config", None) is None:
self.algorithm.policy_config = self.policy self.algorithm.policy_config = self.policy
+1 -5
View File
@@ -73,11 +73,7 @@ class RLTrainer:
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType: def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
"""Apply policy preprocessing to RL observations only. """Apply policy preprocessing to RL observations only."""
This mirrors the pre-refactor SAC learner behavior where actions are left
unchanged and only state/next_state observations are normalized.
"""
observations = batch["state"] observations = batch["state"]
next_observations = batch["next_state"] next_observations = batch["next_state"]
batch["state"] = preprocessor.process_observation(observations) batch["state"] = preprocessor.process_observation(observations)
-5
View File
@@ -303,11 +303,6 @@ def test_end_to_end_parameters_flow(cfg, data_size):
assert torch.allclose(received_params[key], input_params[key]) assert torch.allclose(received_params[key], input_params[key])
# ---------------------------------------------------------------------------
# Regression test: learner algorithm integration (no gRPC required)
# ---------------------------------------------------------------------------
def test_learner_algorithm_wiring(): def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config, """Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers_and_scheduler() creates the right optimizers, update() works, and make_optimizers_and_scheduler() creates the right optimizers, update() works, and