diff --git a/src/lerobot/processor/relative_action_processor.py b/src/lerobot/processor/relative_action_processor.py index 5b1039291..577080f76 100644 --- a/src/lerobot/processor/relative_action_processor.py +++ b/src/lerobot/processor/relative_action_processor.py @@ -146,6 +146,9 @@ class RelativeActionsProcessorStep(ProcessorStep): """Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions.""" return self._last_state + def reset(self) -> None: + self._last_state = None + def get_config(self) -> dict[str, Any]: return { "enabled": self.enabled, diff --git a/src/lerobot/rollout/robot_wrapper.py b/src/lerobot/rollout/robot_wrapper.py index 44f744812..cf33d4600 100644 --- a/src/lerobot/rollout/robot_wrapper.py +++ b/src/lerobot/rollout/robot_wrapper.py @@ -36,6 +36,7 @@ class ThreadSafeRobot: def __init__(self, robot: Robot) -> None: self._robot = robot self._lock = Lock() + self._last_action_response: Any | None = None # -- Lock-protected I/O -------------------------------------------------- @@ -45,7 +46,15 @@ class ThreadSafeRobot: def send_action(self, action: dict[str, Any] | Any) -> Any: with self._lock: - return self._robot.send_action(action) + response = self._robot.send_action(action) + self._last_action_response = getattr(self._robot, "last_action_response", response) + return response + + def pop_last_action_response(self) -> Any | None: + with self._lock: + response = self._last_action_response + self._last_action_response = None + return response # -- Read-only proxies (no lock needed) ----------------------------------- diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index e47b65209..a53a31809 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -67,6 +67,8 @@ class BaseStrategy(RolloutStrategy): action_dict = send_next_action(obs_processed, obs, ctx, interpolator) self._log_telemetry(obs_processed, action_dict, ctx.runtime) + if action_dict is not None: + self._reset_inference_after_robot_episode_done(ctx) dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 9c897522f..f9d525666 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -116,6 +116,22 @@ class RolloutStrategy(abc.ABC): engine.resume() return False + def _reset_inference_after_robot_episode_done(self, ctx: RolloutContext) -> None: + """Reset rollout-side episode state when a robot backend reports an environment reset.""" + response = ctx.hardware.robot_wrapper.pop_last_action_response() + if not isinstance(response, dict) or not response.get("done"): + return + + logger.info( + "Robot reported episode done (success=%s); resetting rollout inference state", + response.get("success"), + ) + if self._engine is not None: + self._engine.reset() + if self._interpolator is not None: + self._interpolator.reset() + self._cached_obs_processed = None + def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None: """Stop the inference engine, optionally return robot to initial position, and disconnect hardware.""" if self._engine is not None: diff --git a/tests/policies/test_relative_actions.py b/tests/policies/test_relative_actions.py index 15ef0a31b..83c1748a5 100644 --- a/tests/policies/test_relative_actions.py +++ b/tests/policies/test_relative_actions.py @@ -171,6 +171,23 @@ def test_full_pipeline_roundtrip(dataset, action_dim): torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4) +def test_relative_actions_processor_reset_clears_cached_state(): + relative_step = RelativeActionsProcessorStep(enabled=True) + transition = batch_to_transition( + { + OBS_STATE: torch.tensor([[1.0, 2.0]]), + ACTION: torch.tensor([[[1.5, 1.0]]]), + } + ) + + relative_step(transition) + assert relative_step.get_cached_state() is not None + + relative_step.reset() + + assert relative_step.get_cached_state() is None + + def test_normalized_relative_values_are_reasonable(dataset, action_dim): """With correct chunk stats, normalized relative actions should be in a reasonable range.""" action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE) diff --git a/tests/test_rollout.py b/tests/test_rollout.py index 85a29ff4c..76f0b32a2 100644 --- a/tests/test_rollout.py +++ b/tests/test_rollout.py @@ -177,6 +177,26 @@ def test_thread_safe_robot_delegates(): robot.disconnect() +def test_thread_safe_robot_tracks_action_response_metadata(): + from lerobot.rollout.robot_wrapper import ThreadSafeRobot + + class MetadataRobot: + def get_observation(self): + return {} + + def send_action(self, action): + self.last_action_response = {"action": action, "done": True, "success": False} + return action + + robot = MetadataRobot() + wrapper = ThreadSafeRobot(robot) + + action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0} + assert wrapper.send_action(action) == action + assert wrapper.pop_last_action_response() == {"action": action, "done": True, "success": False} + assert wrapper.pop_last_action_response() is None + + def test_thread_safe_robot_properties(): from lerobot.rollout.robot_wrapper import ThreadSafeRobot from tests.mocks.mock_robot import MockRobot, MockRobotConfig @@ -194,6 +214,42 @@ def test_thread_safe_robot_properties(): robot.disconnect() +def test_base_strategy_resets_inference_when_robot_reports_episode_done(): + from lerobot.rollout import BaseStrategy, BaseStrategyConfig + + strategy = BaseStrategy(BaseStrategyConfig()) + strategy._engine = MagicMock() + strategy._interpolator = MagicMock() + strategy._cached_obs_processed = {"stale": True} + + ctx = MagicMock() + ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": True, "success": False} + + strategy._reset_inference_after_robot_episode_done(ctx) + + strategy._engine.reset.assert_called_once() + strategy._interpolator.reset.assert_called_once() + assert strategy._cached_obs_processed is None + + +def test_base_strategy_ignores_action_responses_without_episode_done(): + from lerobot.rollout import BaseStrategy, BaseStrategyConfig + + strategy = BaseStrategy(BaseStrategyConfig()) + strategy._engine = MagicMock() + strategy._interpolator = MagicMock() + strategy._cached_obs_processed = {"cached": True} + + ctx = MagicMock() + ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": False} + + strategy._reset_inference_after_robot_episode_done(ctx) + + strategy._engine.reset.assert_not_called() + strategy._interpolator.reset.assert_not_called() + assert strategy._cached_obs_processed == {"cached": True} + + # --------------------------------------------------------------------------- # Strategy factory # ---------------------------------------------------------------------------