mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
refactor: simplify docstrings for clarity and conciseness across multiple files
This commit is contained in:
@@ -36,15 +36,6 @@ DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension
|
|||||||
class GaussianActorPolicy(
|
class GaussianActorPolicy(
|
||||||
PreTrainedPolicy,
|
PreTrainedPolicy,
|
||||||
):
|
):
|
||||||
"""Gaussian actor + observation encoder.
|
|
||||||
|
|
||||||
Policy-side ``nn.Module`` used by SAC and related maximum-entropy continuous
|
|
||||||
control algorithms. It owns the actor network (``Policy``) and the observation
|
|
||||||
encoder (``GaussianActorObservationEncoder``); the critics, temperature, and
|
|
||||||
Bellman-update logic live on the algorithm side
|
|
||||||
(see ``lerobot.rl.algorithms.sac``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
config_class = GaussianActorConfig
|
config_class = GaussianActorConfig
|
||||||
name = "gaussian_actor"
|
name = "gaussian_actor"
|
||||||
|
|
||||||
|
|||||||
@@ -291,11 +291,7 @@ def act_with_policy(
|
|||||||
with policy_timer:
|
with policy_timer:
|
||||||
normalized_observation = preprocessor.process_observation(observation)
|
normalized_observation = preprocessor.process_observation(observation)
|
||||||
action = policy.select_action(batch=normalized_observation)
|
action = policy.select_action(batch=normalized_observation)
|
||||||
# Unnormalize only the continuous part. When `num_discrete_actions` is set,
|
# Unnormalize only the continuous part.
|
||||||
# `select_action` concatenates an argmax index in env space at the last dim;
|
|
||||||
# action stats cover the continuous dims only, so feeding the full vector to
|
|
||||||
# the unnormalizer would shape-mismatch and would also corrupt the discrete
|
|
||||||
# index by treating it as a normalized value.
|
|
||||||
if cfg.policy.num_discrete_actions is not None:
|
if cfg.policy.num_discrete_actions is not None:
|
||||||
continuous_action = postprocessor.process_action(action[..., :-1])
|
continuous_action = postprocessor.process_action(action[..., :-1])
|
||||||
discrete_action = action[..., -1:].to(
|
discrete_action = action[..., -1:].to(
|
||||||
@@ -346,9 +342,6 @@ def act_with_policy(
|
|||||||
"discrete_penalty": torch.tensor(
|
"discrete_penalty": torch.tensor(
|
||||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||||
),
|
),
|
||||||
# Forward the intervention flag so the learner can route this transition
|
|
||||||
# into the offline replay buffer (see `process_transitions` in learner.py).
|
|
||||||
# Use the plain string key so the payload survives torch.load(weights_only=True).
|
|
||||||
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
TeleopEvents.IS_INTERVENTION.value: is_intervention,
|
||||||
}
|
}
|
||||||
# Create transition for learner (convert to old format)
|
# Create transition for learner (convert to old format)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from lerobot.rl.algorithms.configs import RLAlgorithmConfig
|
|||||||
|
|
||||||
|
|
||||||
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
def make_algorithm_config(algorithm_type: str, **kwargs) -> RLAlgorithmConfig:
|
||||||
"""Instantiate an :class:`RLAlgorithmConfig` from its registered type name.
|
"""Instantiate an `RLAlgorithmConfig` from its registered type name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
algorithm_type: Registry key of the algorithm (e.g. ``"sac"``).
|
||||||
|
|||||||
@@ -21,11 +21,7 @@ from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
|||||||
|
|
||||||
|
|
||||||
class DataMixer(abc.ABC):
|
class DataMixer(abc.ABC):
|
||||||
"""Abstract interface for all data mixing strategies.
|
"""Abstract interface for all data mixing strategies."""
|
||||||
|
|
||||||
Subclasses must implement ``sample(batch_size)`` and may override
|
|
||||||
``get_iterator`` for specialised iteration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def sample(self, batch_size: int) -> BatchType:
|
def sample(self, batch_size: int) -> BatchType:
|
||||||
@@ -38,25 +34,13 @@ class DataMixer(abc.ABC):
|
|||||||
async_prefetch: bool = True,
|
async_prefetch: bool = True,
|
||||||
queue_size: int = 2,
|
queue_size: int = 2,
|
||||||
):
|
):
|
||||||
"""Infinite iterator that yields batches.
|
"""Infinite iterator that yields batches."""
|
||||||
|
|
||||||
The default implementation repeatedly calls ``self.sample()``.
|
|
||||||
Subclasses with underlying buffer iterators (async prefetch)
|
|
||||||
should override this for better throughput.
|
|
||||||
"""
|
|
||||||
while True:
|
while True:
|
||||||
yield self.sample(batch_size)
|
yield self.sample(batch_size)
|
||||||
|
|
||||||
|
|
||||||
class OnlineOfflineMixer(DataMixer):
|
class OnlineOfflineMixer(DataMixer):
|
||||||
"""Mixes transitions from an online and an optional offline replay buffer.
|
"""Mixes transitions from an online and an offline replay buffer."""
|
||||||
|
|
||||||
When both buffers are present, each batch is constructed by sampling
|
|
||||||
``ceil(batch_size * online_ratio)`` from the online buffer and the
|
|
||||||
remainder from the offline buffer, then concatenating.
|
|
||||||
|
|
||||||
This mixer assumes both online and offline buffers are present.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -31,10 +31,7 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
|||||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||||
|
|
||||||
# Algorithm config (a `draccus.ChoiceRegistry` subclass selected by `type`,
|
# Algorithm config.
|
||||||
# e.g. ``"type": "sac"``). When omitted, defaults to a SAC config with
|
|
||||||
# default hyperparameters. The top-level `policy` is injected into
|
|
||||||
# ``algorithm.policy_config`` at validation time.
|
|
||||||
algorithm: RLAlgorithmConfig | None = None
|
algorithm: RLAlgorithmConfig | None = None
|
||||||
|
|
||||||
# Data mixer strategy name. Currently supports "online_offline".
|
# Data mixer strategy name. Currently supports "online_offline".
|
||||||
@@ -48,7 +45,5 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
|||||||
if self.algorithm is None:
|
if self.algorithm is None:
|
||||||
self.algorithm = make_algorithm_config("sac")
|
self.algorithm = make_algorithm_config("sac")
|
||||||
|
|
||||||
# The pipeline owns the policy config; inject it so the algorithm can
|
|
||||||
# introspect policy architecture (e.g. ``num_discrete_actions``).
|
|
||||||
if getattr(self.algorithm, "policy_config", None) is None:
|
if getattr(self.algorithm, "policy_config", None) is None:
|
||||||
self.algorithm.policy_config = self.policy
|
self.algorithm.policy_config = self.policy
|
||||||
|
|||||||
@@ -73,11 +73,7 @@ class RLTrainer:
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
|
def preprocess_rl_batch(preprocessor: Any, batch: BatchType) -> BatchType:
|
||||||
"""Apply policy preprocessing to RL observations only.
|
"""Apply policy preprocessing to RL observations only."""
|
||||||
|
|
||||||
This mirrors the pre-refactor SAC learner behavior where actions are left
|
|
||||||
unchanged and only state/next_state observations are normalized.
|
|
||||||
"""
|
|
||||||
observations = batch["state"]
|
observations = batch["state"]
|
||||||
next_observations = batch["next_state"]
|
next_observations = batch["next_state"]
|
||||||
batch["state"] = preprocessor.process_observation(observations)
|
batch["state"] = preprocessor.process_observation(observations)
|
||||||
|
|||||||
@@ -303,11 +303,6 @@ def test_end_to_end_parameters_flow(cfg, data_size):
|
|||||||
assert torch.allclose(received_params[key], input_params[key])
|
assert torch.allclose(received_params[key], input_params[key])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Regression test: learner algorithm integration (no gRPC required)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_learner_algorithm_wiring():
|
def test_learner_algorithm_wiring():
|
||||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||||
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
|
make_optimizers_and_scheduler() creates the right optimizers, update() works, and
|
||||||
|
|||||||
Reference in New Issue
Block a user