mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
refactor: simplify docstrings for clarity and conciseness across multiple files
This commit is contained in:
@@ -36,15 +36,6 @@ DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
||||
class GaussianActorPolicy(
|
||||
PreTrainedPolicy,
|
||||
):
|
||||
"""Gaussian actor + observation encoder.
|
||||
|
||||
Policy-side ``nn.Module`` used by SAC and related maximum-entropy continuous
|
||||
control algorithms. It owns the actor network (``Policy``) and the observation
|
||||
encoder (``GaussianActorObservationEncoder``); the critics, temperature, and
|
||||
Bellman-update logic live on the algorithm side
|
||||
(see ``lerobot.rl.algorithms.sac``).
|
||||
"""
|
||||
|
||||
config_class = GaussianActorConfig
|
||||
name = "gaussian_actor"
|
||||
|
||||
|
||||
@@ -291,11 +291,7 @@ def act_with_policy(
|
||||
with policy_timer:
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
|
||||
# `select_action` concatenates an argmax index in env space at the last dim;
|
||||
# action stats cover the continuous dims only, so feeding the full vector to
|
||||
# the unnormalizer would shape-mismatch and would also corrupt the discrete
|
||||
# index by treating it as a normalized value.
|
||||
# Unnormalize only the continuous part.
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||
discrete_action = action[..., -1:].to(
|
||||
@@ -346,9 +342,6 @@ def act_with_policy(
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
# Forward the intervention flag so the learner can route this transition
|
||||
# into the offline replay buffer (see `process_transitions` in learner.py).
|
||||
# Use the plain string key so the payload survives torch.load(weights_only=True).
|
||||
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
|
||||
@@ -21,7 +21,7 @@ from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
||||
"""Instantiate an :class:`RLAlgorithmConfig` from its registered type name.
|
||||
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
|
||||
|
||||
Args:
|
||||
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
||||
|
||||
@@ -21,11 +21,7 @@ from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies.
|
||||
|
||||
Subclasses must implement ``sample(batch_size)`` and may override
|
||||
``get_iterator`` for specialised iteration.
|
||||
"""
|
||||
"""Abstract interface for all data mixing strategies."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
@@ -38,25 +34,13 @@ class DataMixer(abc.ABC):
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches.
|
||||
|
||||
The default implementation repeatedly calls ``self.sample()``.
|
||||
Subclasses with underlying buffer iterators (async prefetch)
|
||||
should override this for better throughput.
|
||||
"""
|
||||
"""Infinite iterator that yields batches."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an optional offline replay buffer.
|
||||
|
||||
When both buffers are present, each batch is constructed by sampling
|
||||
``ceil(batch_size * online_ratio)`` from the online buffer and the
|
||||
remainder from the offline buffer, then concatenating.
|
||||
|
||||
This mixer assumes both online and offline buffers are present.
|
||||
"""
|
||||
"""Mixes transitions from an online and an offline replay buffer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -31,10 +31,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
|
||||
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
|
||||
# default hyperparameters. The top-level `policy` is injected into
|
||||
# ``algorithm.policy_config`` at validation time.
|
||||
# Algorithm config.
|
||||
algorithm: RLAlgorithmConfig | None = None
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline".
|
||||
@@ -48,7 +45,5 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
if self.algorithm is None:
|
||||
self.algorithm = make_algorithm_config("sac")
|
||||
|
||||
# The pipeline owns the policy config; inject it so the algorithm can
|
||||
# introspect policy architecture (e.g. ``num_discrete_actions``).
|
||||
if getattr(self.algorithm, "policy_config", None) is None:
|
||||
self.algorithm.policy_config = self.policy
|
||||
|
||||
@@ -73,11 +73,7 @@ class RLTrainer:
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
|
||||
"""Apply policy preprocessing to RL observations only.
|
||||
|
||||
This mirrors the pre-refactor SAC learner behavior where actions are left
|
||||
unchanged and only state/next_state observations are normalized.
|
||||
"""
|
||||
"""Apply policy preprocessing to RL observations only."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
batch["state"] = preprocessor.process_observation(observations)
|
||||
|
||||
@@ -303,11 +303,6 @@ def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
assert torch.allclose(received_params[key], input_params[key])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression test: learner algorithm integration (no gRPC required)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_learner_algorithm_wiring():
|
||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
|
||||
|
||||
Reference in New Issue
Block a user