mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
Reset rollout state after robot episode end
This commit is contained in:
@@ -146,6 +146,9 @@ class RelativeActionsProcessorStep(ProcessorStep):
|
||||
"""Return the cached ``observation.state`` used as the reference point for relative/absolute action conversions."""
|
||||
return self._last_state
|
||||
|
||||
def reset(self) -> None:
    """Forget the cached reference state.

    After this call ``self._last_state`` is ``None``, so ``get_cached_state()``
    returns ``None`` until a new state is stored.
    """
    self._last_state = None
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
|
||||
@@ -36,6 +36,7 @@ class ThreadSafeRobot:
|
||||
def __init__(self, robot: Robot) -> None:
    """Wrap *robot* and create the lock plus an empty action-response slot."""
    self._robot = robot
    # Held around the wrapped robot's send_action and around reads/clears of the slot below.
    self._lock = Lock()
    # Latest action response; pop_last_action_response() returns it and resets it to None.
    self._last_action_response: Any | None = None
|
||||
|
||||
# -- Lock-protected I/O --------------------------------------------------
|
||||
|
||||
@@ -45,7 +46,15 @@ class ThreadSafeRobot:
|
||||
|
||||
def send_action(self, action: dict[str, Any] | Any) -> Any:
    """Send *action* to the wrapped robot under the lock and remember the response.

    The value kept for ``pop_last_action_response`` is the robot's
    ``last_action_response`` attribute when it has one, otherwise the value
    returned by the robot's ``send_action``.

    Returns:
        Whatever the wrapped robot's ``send_action`` returned.
    """
    with self._lock:
        # No early return before this point: the response has to be recorded
        # while the lock is still held, otherwise it is never stored.
        response = self._robot.send_action(action)
        self._last_action_response = getattr(self._robot, "last_action_response", response)
        return response
|
||||
|
||||
def pop_last_action_response(self) -> Any | None:
    """Return the stored action response and clear it.

    Returns:
        The value currently held in ``self._last_action_response``, or ``None``
        when nothing is stored. The slot is ``None`` after the call.
    """
    with self._lock:
        # Read and clear in one step while holding the lock.
        pending, self._last_action_response = self._last_action_response, None
    return pending
|
||||
|
||||
# -- Read-only proxies (no lock needed) -----------------------------------
|
||||
|
||||
|
||||
@@ -67,6 +67,8 @@ class BaseStrategy(RolloutStrategy):
|
||||
|
||||
action_dict = send_next_action(obs_processed, obs, ctx, interpolator)
|
||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||
if action_dict is not None:
|
||||
self._reset_inference_after_robot_episode_done(ctx)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
|
||||
@@ -116,6 +116,22 @@ class RolloutStrategy(abc.ABC):
|
||||
engine.resume()
|
||||
return False
|
||||
|
||||
def _reset_inference_after_robot_episode_done(self, ctx: RolloutContext) -> None:
    """Reset rollout-side episode state when a robot backend reports an environment reset.

    Pops the latest action response from ``ctx.hardware.robot_wrapper``. Nothing
    happens unless that response is a dict with a truthy ``"done"`` entry. In that
    case the engine and the interpolator are reset (each only when it is set) and
    the cached processed observation is dropped.
    """
    response = ctx.hardware.robot_wrapper.pop_last_action_response()
    if not isinstance(response, dict) or not response.get("done"):
        return

    logger.info(
        "Robot reported episode done (success=%s); resetting rollout inference state",
        response.get("success"),
    )
    # Engine first, then interpolator; either may be unset.
    for component in (self._engine, self._interpolator):
        if component is not None:
            component.reset()
    self._cached_obs_processed = None
|
||||
|
||||
def _teardown_hardware(self, hw: HardwareContext, return_to_initial_position: bool = True) -> None:
|
||||
"""Stop the inference engine, optionally return robot to initial position, and disconnect hardware."""
|
||||
if self._engine is not None:
|
||||
|
||||
@@ -171,6 +171,23 @@ def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
||||
|
||||
|
||||
def test_relative_actions_processor_reset_clears_cached_state():
    """``reset()`` clears the state cached by a previous call of the step."""
    relative_step = RelativeActionsProcessorStep(enabled=True)
    transition = batch_to_transition(
        {
            OBS_STATE: torch.tensor([[1.0, 2.0]]),
            ACTION: torch.tensor([[[1.5, 1.0]]]),
        }
    )

    # Running the step once populates the cache.
    relative_step(transition)
    assert relative_step.get_cached_state() is not None

    relative_step.reset()

    assert relative_step.get_cached_state() is None
|
||||
|
||||
|
||||
def test_normalized_relative_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized relative actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
@@ -175,6 +175,26 @@ def test_thread_safe_robot_delegates():
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_thread_safe_robot_tracks_action_response_metadata():
    """The wrapper exposes the robot's ``last_action_response`` exactly once per pop."""
    from lerobot.rollout.robot_wrapper import ThreadSafeRobot

    class MetadataRobot:
        """Minimal robot stand-in that records response metadata on every action."""

        def get_observation(self):
            return {}

        def send_action(self, action):
            self.last_action_response = {"action": action, "done": True, "success": False}
            return action

    robot = MetadataRobot()
    wrapper = ThreadSafeRobot(robot)

    action = {"motor_1.pos": 0.0, "motor_2.pos": 1.0, "motor_3.pos": 2.0}
    assert wrapper.send_action(action) == action
    assert wrapper.pop_last_action_response() == {"action": action, "done": True, "success": False}
    # A second pop finds nothing: the first one cleared the stored response.
    assert wrapper.pop_last_action_response() is None
|
||||
|
||||
|
||||
def test_thread_safe_robot_properties():
|
||||
from lerobot.rollout.robot_wrapper import ThreadSafeRobot
|
||||
from tests.mocks.mock_robot import MockRobot, MockRobotConfig
|
||||
@@ -192,6 +212,42 @@ def test_thread_safe_robot_properties():
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
def test_base_strategy_resets_inference_when_robot_reports_episode_done():
    """A ``done`` action response resets the engine, the interpolator and the cached observation."""
    from lerobot.rollout import BaseStrategy, BaseStrategyConfig

    strategy = BaseStrategy(BaseStrategyConfig())
    strategy._engine = MagicMock()
    strategy._interpolator = MagicMock()
    strategy._cached_obs_processed = {"stale": True}

    ctx = MagicMock()
    ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": True, "success": False}

    strategy._reset_inference_after_robot_episode_done(ctx)

    strategy._engine.reset.assert_called_once()
    strategy._interpolator.reset.assert_called_once()
    assert strategy._cached_obs_processed is None
|
||||
|
||||
|
||||
def test_base_strategy_ignores_action_responses_without_episode_done():
    """A response whose ``done`` flag is false leaves all rollout state untouched."""
    from lerobot.rollout import BaseStrategy, BaseStrategyConfig

    strategy = BaseStrategy(BaseStrategyConfig())
    strategy._engine = MagicMock()
    strategy._interpolator = MagicMock()
    strategy._cached_obs_processed = {"cached": True}

    ctx = MagicMock()
    ctx.hardware.robot_wrapper.pop_last_action_response.return_value = {"done": False}

    strategy._reset_inference_after_robot_episode_done(ctx)

    strategy._engine.reset.assert_not_called()
    strategy._interpolator.reset.assert_not_called()
    assert strategy._cached_obs_processed == {"cached": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user