Revert "Reset rollout state after robot episode end"

This reverts commit 1322f45aec.
This commit is contained in:
Andy Wrenn
2026-06-21 08:26:48 -07:00
parent b8dcc51f35
commit 31f7979498
6 changed files with 1 additions and 104 deletions
@@ -146,9 +146,6 @@ class RelativeActionsProcessorStep(ProcessorStep):
"""Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions."""
return self._last_state
def reset(self) -> None:
self._last_state = None
def get_config(self) -> dict[str, Any]:
return {
"enabled": self.enabled,
+1 -10
View File
@@ -36,7 +36,6 @@ class ThreadSafeRobot:
def __init__(self, robot: Robot) -> None:
self._robot = robot
self._lock = Lock()
self._last_action_response: Any | None = None
# -- Lock-protected I/O --------------------------------------------------
@@ -46,15 +45,7 @@ class ThreadSafeRobot:
def send_action(self, action: dict[str, Any] | Any) -> Any:
with self._lock:
response = self._robot.send_action(action)
self._last_action_response = getattr(self._robot, "last_action_response", response)
return response
def pop_last_action_response(self) -> Any | None:
with self._lock:
response = self._last_action_response
self._last_action_response = None
return response
return self._robot.send_action(action)
# -- Read-only proxies (no lock needed) -----------------------------------
-2
View File
@@ -67,8 +67,6 @@ class BaseStrategy(RolloutStrategy):
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
if action_dict is not None:
self._reset_inference_after_robot_episode_done(ctx)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
-16
View File
@@ -116,22 +116,6 @@ class RolloutStrategy(abc.ABC):
engine.resume()
return False
def _reset_inference_after_robot_episode_done(self, ctx: RolloutContext) -> None:
"""Reset rollout-side episode state when a robot backend reports an environment reset."""
response = ctx.hardware.robot_wrapper.pop_last_action_response()
if not isinstance(response, dict) or not response.get("done"):
return
logger.info(
"Robot reported episode done (success=%s); resetting rollout inference state",
response.get("success"),
)
if self._engine is not None:
self._engine.reset()
if self._interpolator is not None:
self._interpolator.reset()
self._cached_obs_processed = None
def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None:
"""Stop the inference engine, optionally return robot to initial position, and disconnect hardware."""
if self._engine is not None:
-17
View File
@@ -171,23 +171,6 @@ def test_full_pipeline_roundtrip(dataset, action_dim):
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
def test_relative_actions_processor_reset_clears_cached_state():
relative_step = RelativeActionsProcessorStep(enabled=True)
transition = batch_to_transition(
{
OBS_STATE: torch.tensor([[1.0, 2.0]]),
ACTION: torch.tensor([[[1.5, 1.0]]]),
}
)
relative_step(transition)
assert relative_step.get_cached_state() is not None
relative_step.reset()
assert relative_step.get_cached_state() is None
def test_normalized_relative_values_are_reasonable(dataset, action_dim):
"""With correct chunk stats, normalized relative actions should be in a reasonable range."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
-56
View File
@@ -177,26 +177,6 @@ def test_thread_safe_robot_delegates():
robot.disconnect()
def test_thread_safe_robot_tracks_action_response_metadata():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
class MetadataRobot:
def get_observation(self):
return {}
def send_action(self, action):
self.last_action_response = {"action": action, "done": True, "success": False}
return action
robot = MetadataRobot()
wrapper = ThreadSafeRobot(robot)
action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
assert wrapper.send_action(action) == action
assert wrapper.pop_last_action_response() == {"action": action, "done": True, "success": False}
assert wrapper.pop_last_action_response() is None
def test_thread_safe_robot_properties():
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
@@ -214,42 +194,6 @@ def test_thread_safe_robot_properties():
robot.disconnect()
def test_base_strategy_resets_inference_when_robot_reports_episode_done():
from lerobot.rollout import BaseStrategy, BaseStrategyConfig
strategy = BaseStrategy(BaseStrategyConfig())
strategy._engine = MagicMock()
strategy._interpolator = MagicMock()
strategy._cached_obs_processed = {"stale": True}
ctx = MagicMock()
ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": True, "success": False}
strategy._reset_inference_after_robot_episode_done(ctx)
strategy._engine.reset.assert_called_once()
strategy._interpolator.reset.assert_called_once()
assert strategy._cached_obs_processed is None
def test_base_strategy_ignores_action_responses_without_episode_done():
from lerobot.rollout import BaseStrategy, BaseStrategyConfig
strategy = BaseStrategy(BaseStrategyConfig())
strategy._engine = MagicMock()
strategy._interpolator = MagicMock()
strategy._cached_obs_processed = {"cached": True}
ctx = MagicMock()
ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": False}
strategy._reset_inference_after_robot_episode_done(ctx)
strategy._engine.reset.assert_not_called()
strategy._interpolator.reset.assert_not_called()
assert strategy._cached_obs_processed == {"cached": True}
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------