From 31f7979498397f7591653ed04df830ec5cadcc0d Mon Sep 17 00:00:00 2001 From: Andy Wrenn Date: Sun, 21 Jun 2026 08:26:48 -0700 Subject: [PATCH] Revert "Reset rollout state after robot episode end" This reverts commit 1322f45aec088d3ca346640d995d28edcf71d00f. --- .../processor/relative_action_processor.py | 3 - src/lerobot/rollout/robot_wrapper.py | 11 +--- src/lerobot/rollout/strategies/base.py | 2 - src/lerobot/rollout/strategies/core.py | 16 ------ tests/policies/test_relative_actions.py | 17 ------ tests/test_rollout.py | 56 ------------------- 6 files changed, 1 insertion(+), 104 deletions(-) diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py index 577080f76..5b1039291 100644 --- a/src/lerobot/processor/relative_action_processor.py +++ b/src/lerobot/processor/relative_action_processor.py @@ -146,9 +146,6 @@ class RelativeActionsProcessorStep(ProcessorStep): """Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions.""" return self._last_state - def reset(self) -> None: - self._last_state = None - def get_config(self) -> dict[str, Any]: return { "enabled": self.enabled, diff --git a/src/lerobot/rollout/robot_wrapper.py b/src/lerobot/rollout/robot_wrapper.py index cf33d4600..44f744812 100644 --- a/src/lerobot/rollout/robot_wrapper.py +++ b/src/lerobot/rollout/robot_wrapper.py @@ -36,7 +36,6 @@ class ThreadSafeRobot: def __init__(self, robot: Robot) -> None: self._robot = robot self._lock = Lock() - self._last_action_response: Any | None = None # -- Lock-protected I/O -------------------------------------------------- @@ -46,15 +45,7 @@ class ThreadSafeRobot: def send_action(self, action: dict[str, Any] | Any) -> Any: with self._lock: - response = self._robot.send_action(action) - self._last_action_response = getattr(self._robot, "last_action_response", response) - return response - - def pop_last_action_response(self) -> Any | None: - with self._lock: - response = self._last_action_response - self._last_action_response = None - return response + return self._robot.send_action(action) # -- Read-only proxies (no lock needed) ----------------------------------- diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index a53a31809..e47b65209 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -67,8 +67,6 @@ class BaseStrategy(RolloutStrategy): action_dict = send_next_action(obs_processed, obs, ctx, interpolator) self._log_telemetry(obs_processed, action_dict, ctx.runtime) - if action_dict is not None: - self._reset_inference_after_robot_episode_done(ctx) dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index f9d525666..9c897522f 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -116,22 +116,6 @@ class RolloutStrategy(abc.ABC): engine.resume() return False - def _reset_inference_after_robot_episode_done(self, ctx: RolloutContext) -> None: - """Reset rollout-side episode state when a robot backend reports an environment reset.""" - response = ctx.hardware.robot_wrapper.pop_last_action_response() - if not isinstance(response, dict) or not response.get("done"): - return - - logger.info( - "Robot reported episode done (success=%s); resetting rollout inference state", - response.get("success"), - ) - if self._engine is not None: - self._engine.reset() - if self._interpolator is not None: - self._interpolator.reset() - self._cached_obs_processed = None - def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None: """Stop the inference engine, optionally return robot to initial position, and disconnect hardware.""" if self._engine is not None: diff --git a/tests/policies/test_relative_actions.py b/tests/policies/test_relative_actions.py index 83c1748a5..15ef0a31b 100644 --- a/tests/policies/test_relative_actions.py +++ b/tests/policies/test_relative_actions.py @@ -171,23 +171,6 @@ def test_full_pipeline_roundtrip(dataset, action_dim): torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4) -def test_relative_actions_processor_reset_clears_cached_state(): - relative_step = RelativeActionsProcessorStep(enabled=True) - transition = batch_to_transition( - { - OBS_STATE: torch.tensor([[1.0, 2.0]]), - ACTION: torch.tensor([[[1.5, 1.0]]]), - } - ) - - relative_step(transition) - assert relative_step.get_cached_state() is not None - - relative_step.reset() - - assert relative_step.get_cached_state() is None - - def test_normalized_relative_values_are_reasonable(dataset, action_dim): """With correct chunk stats, normalized relative actions should be in a reasonable range.""" action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) diff --git a/tests/test_rollout.py b/tests/test_rollout.py index 76f0b32a2..85a29ff4c 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -177,26 +177,6 @@ def test_thread_safe_robot_delegates(): robot.disconnect() -def test_thread_safe_robot_tracks_action_response_metadata(): - from lerobot.rollout.robot_wrapper import ThreadSafeRobot - - class MetadataRobot: - def get_observation(self): - return {} - - def send_action(self, action): - self.last_action_response = {"action": action, "done": True, "success": False} - return action - - robot = MetadataRobot() - wrapper = ThreadSafeRobot(robot) - - action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0} - assert wrapper.send_action(action) == action - assert wrapper.pop_last_action_response() == {"action": action, "done": True, "success": False} - assert wrapper.pop_last_action_response() is None - - def test_thread_safe_robot_properties(): from lerobot.rollout.robot_wrapper import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig @@ -214,42 +194,6 @@ def test_thread_safe_robot_properties(): robot.disconnect() -def test_base_strategy_resets_inference_when_robot_reports_episode_done(): - from lerobot.rollout import BaseStrategy, BaseStrategyConfig - - strategy = BaseStrategy(BaseStrategyConfig()) - strategy._engine = MagicMock() - strategy._interpolator = MagicMock() - strategy._cached_obs_processed = {"stale": True} - - ctx = MagicMock() - ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": True, "success": False} - - strategy._reset_inference_after_robot_episode_done(ctx) - - strategy._engine.reset.assert_called_once() - strategy._interpolator.reset.assert_called_once() - assert strategy._cached_obs_processed is None - - -def test_base_strategy_ignores_action_responses_without_episode_done(): - from lerobot.rollout import BaseStrategy, BaseStrategyConfig - - strategy = BaseStrategy(BaseStrategyConfig()) - strategy._engine = MagicMock() - strategy._interpolator = MagicMock() - strategy._cached_obs_processed = {"cached": True} - - ctx = MagicMock() - ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": False} - - strategy._reset_inference_after_robot_episode_done(ctx) - - strategy._engine.reset.assert_not_called() - strategy._interpolator.reset.assert_not_called() - assert strategy._cached_obs_processed == {"cached": True} - - # --------------------------------------------------------------------------- # Strategy factory # ---------------------------------------------------------------------------