mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Merge remote-tracking branch 'origin/main' into user/khalil-meftah/2026-02-16-rl-stack-refactor
# Conflicts: # src/lerobot/policies/__init__.py # src/lerobot/rl/actor.py
This commit is contained in:
@@ -17,9 +17,9 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_interpolator import ActionInterpolator
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.action_interpolator import ActionInterpolator
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def test_rtc_config_default_initialization():
|
||||
"""Test RTCConfig initializes with default values."""
|
||||
config = RTCConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.enabled is True
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 10
|
||||
|
||||
@@ -22,7 +22,7 @@ from lerobot.configs.types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor import TransitionKey, batch_to_transition, create_transition
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.processor.relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
@@ -52,6 +52,9 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
|
||||
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
||||
RTCProcessor = _rtc_mod.RTCProcessor
|
||||
|
||||
_rtc_relative_mod = _import_rtc_module("lerobot.policies.rtc.relative", "relative.py")
|
||||
reanchor_relative_rtc_prefix = _rtc_relative_mod.reanchor_relative_rtc_prefix
|
||||
|
||||
ACTION_DIM = 6
|
||||
CHUNK_SIZE = 50
|
||||
EXECUTION_HORIZON = 10
|
||||
@@ -187,7 +190,7 @@ class TestRTCDenoiseWithRelativeLeftovers:
|
||||
|
||||
|
||||
class TestFullPipelineRelativeRTC:
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow."""
|
||||
"""End-to-end test of the RTC + relative actions pipeline matching lerobot-rollout flow."""
|
||||
|
||||
def test_preprocessor_caches_state_for_postprocessor(self):
|
||||
"""Preprocessor's relative step should cache state so postprocessor can convert back."""
|
||||
@@ -218,7 +221,9 @@ class TestFullPipelineRelativeRTC:
|
||||
|
||||
def test_roundtrip_with_identity_normalization(self):
|
||||
"""Actions → relative → normalize → [model] → unnormalize → absolute should recover originals.
|
||||
Using mean=0, std=1 normalization (identity)."""
|
||||
|
||||
Using mean=0, std=1 normalization (identity).
|
||||
"""
|
||||
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
||||
|
||||
state = torch.randn(1, ACTION_DIM)
|
||||
@@ -240,7 +245,7 @@ class TestFullPipelineRelativeRTC:
|
||||
torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_eval_loop_simulation(self):
|
||||
"""Simulate the eval_with_real_robot.py loop with relative actions.
|
||||
"""Simulate the lerobot-rollout loop with relative actions.
|
||||
|
||||
Iteration 1: No leftovers → model generates relative actions → store for RTC
|
||||
Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
|
||||
@@ -400,13 +405,113 @@ class TestStateRebasingApproximation:
|
||||
assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}"
|
||||
|
||||
|
||||
class TestRTCReanchoringWithStateNormalizer:
|
||||
"""RTC re-anchoring under non-identity OBS_STATE normalization."""
|
||||
|
||||
@staticmethod
|
||||
def _build_normalizer_with_state_stats():
|
||||
"""Build a relative-action preprocessor with non-trivial OBS_STATE stats."""
|
||||
features = {
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(ACTION_DIM,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
stats = {
|
||||
ACTION: {
|
||||
"mean": torch.zeros(ACTION_DIM).numpy(),
|
||||
"std": (0.5 * torch.ones(ACTION_DIM)).numpy(),
|
||||
},
|
||||
OBS_STATE: {
|
||||
"mean": (5.0 * torch.ones(ACTION_DIM)).numpy(),
|
||||
"std": (2.0 * torch.ones(ACTION_DIM)).numpy(),
|
||||
},
|
||||
}
|
||||
relative_step = RelativeActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
return relative_step, normalizer
|
||||
|
||||
def test_reanchor_with_raw_state_matches_normalize_of_absolute_minus_state(self):
|
||||
"""Reanchoring with the raw cached state yields ``normalize(prev_actions_absolute - raw_state)``."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
|
||||
prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5)
|
||||
|
||||
result = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=relative_step.get_cached_state(),
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
|
||||
expected_relative = to_relative_actions(prev_actions_absolute, raw_state, [True] * ACTION_DIM)
|
||||
expected = normalizer(create_transition(action=expected_relative))[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_reanchor_with_normalized_state_produces_wrong_result(self):
|
||||
"""Reanchoring with raw vs. normalized state produces meaningfully different outputs."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
|
||||
normalized_obs = normalizer(batch_to_transition({OBS_STATE: raw_state.clone()}))
|
||||
normalized_state = normalized_obs[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
assert not torch.allclose(normalized_state, raw_state)
|
||||
|
||||
prev_actions_absolute = torch.tensor([[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]] * 5)
|
||||
|
||||
result_raw = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=raw_state,
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
result_normalized = reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_actions_absolute,
|
||||
current_state=normalized_state,
|
||||
relative_step=relative_step,
|
||||
normalizer_step=normalizer,
|
||||
policy_device="cpu",
|
||||
)
|
||||
|
||||
max_abs_diff = (result_raw - result_normalized).abs().max()
|
||||
assert max_abs_diff > 0.5, (
|
||||
f"Raw and normalized state produced near-identical outputs (max diff {max_abs_diff:.4f}); "
|
||||
"OBS_STATE stats are too close to identity to be sensitive."
|
||||
)
|
||||
|
||||
def test_engine_pipeline_cached_state_is_raw_after_full_preprocess(self):
|
||||
"""``get_cached_state()`` returns raw OBS_STATE after the full preprocessor pipeline runs."""
|
||||
relative_step, normalizer = self._build_normalizer_with_state_stats()
|
||||
|
||||
raw_state = torch.tensor([[7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
|
||||
|
||||
transition = batch_to_transition({OBS_STATE: raw_state.clone()})
|
||||
transition = relative_step(transition)
|
||||
preprocessed = normalizer(transition)
|
||||
|
||||
cached = relative_step.get_cached_state()
|
||||
torch.testing.assert_close(cached, raw_state, atol=1e-6, rtol=1e-6)
|
||||
|
||||
post_normalize_state = preprocessed[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
assert not torch.allclose(cached, post_normalize_state, atol=1e-3)
|
||||
|
||||
|
||||
def _detect_relative_actions(preprocessor) -> bool:
|
||||
"""Mirror of the helper in eval_with_real_robot.py for testing without importing it."""
|
||||
"""Mirror of the helper in lerobot-rollout for testing without importing it."""
|
||||
return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)
|
||||
|
||||
|
||||
class TestDetectRelativeActions:
|
||||
"""Test the _detect_relative_actions helper logic used by eval_with_real_robot.py."""
|
||||
"""Test the _detect_relative_actions helper logic used by lerobot-rollout."""
|
||||
|
||||
def test_detects_enabled_relative_step(self):
|
||||
class FakePipeline:
|
||||
|
||||
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
if batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
|
||||
Reference in New Issue
Block a user