test dagger

This commit is contained in:
Steven Palma
2026-04-17 16:46:38 +02:00
parent 35bb2c7459
commit a76874f35e
3 changed files with 35 additions and 19 deletions
+2 -2
View File
@@ -96,8 +96,8 @@ class DAggerKeyboardConfig:
"""
pause_resume: str = "space"
correction: str = "c"
upload: str = "h"
correction: str = "tab"
upload: str = "enter"
@dataclass
+11 -9
View File
@@ -236,15 +236,17 @@ def build_rollout_context(
# DAgger requires teleop with motor control capabilities (enable_torque,
# disable_torque, write_goal_positions).
if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
if missing:
teleop.disconnect()
raise ValueError(
f"DAgger strategy requires a teleoperator with motor control methods "
f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
)
# TODO(Steven): either enforce this (i.e. require every teleoperator to implement these methods) or
# make the user responsible for moving the teleop to the robot's position before starting a correction.
# if isinstance(cfg.strategy, DAggerStrategyConfig) and teleop is not None:
# required_teleop_methods = ("enable_torque", "disable_torque", "write_goal_positions")
# missing = [m for m in required_teleop_methods if not callable(getattr(teleop, m, None))]
# if missing:
# teleop.disconnect()
# raise ValueError(
# f"DAgger strategy requires a teleoperator with motor control methods "
# f"{required_teleop_methods}. '{type(teleop).__name__}' is missing: {missing}"
# )
# --- 4. Features + action-key reconciliation ---------------------
all_obs_features = robot.observation_features
+22 -8
View File
@@ -165,7 +165,8 @@ class DAggerEvents:
# Teleoperator helpers
# ---------------------------------------------------------------------------
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
def _teleop_smooth_move_to(
teleop: Teleoperator, target_pos: dict, duration_s: float = 2.0, fps: int = 50
) -> None:
@@ -375,7 +376,9 @@ class DAggerStrategy(RolloutStrategy):
engine.reset()
interpolator.reset()
events.reset()
teleop.disable_torque()
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
@@ -469,7 +472,9 @@ class DAggerStrategy(RolloutStrategy):
finally:
engine.pause()
teleop.disable_torque()
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
@@ -503,7 +508,9 @@ class DAggerStrategy(RolloutStrategy):
engine.reset()
interpolator.reset()
events.reset()
teleop.disable_torque()
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# teleop.disable_torque()
engine.resume()
last_action: dict[str, Any] | None = None
@@ -586,7 +593,9 @@ class DAggerStrategy(RolloutStrategy):
finally:
engine.pause()
teleop.disable_torque()
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# teleop.disable_torque()
with contextlib.suppress(Exception):
with self._episode_lock:
dataset.save_episode()
@@ -609,13 +618,18 @@ class DAggerStrategy(RolloutStrategy):
if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED:
engine.pause()
obs = robot.get_observation()
robot_pos = {
_robot_pos = {
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
}
_teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# _teleop_smooth_move_to(teleop, _robot_pos, duration_s=2.0, fps=50)
elif new_phase == DAggerPhase.CORRECTING:
teleop.disable_torque()
# TODO(Steven): either enforce motor control on every teleoperator (enable_torque, disable_torque,
# write_goal_positions) or make the user responsible for moving the teleop to the robot's position
# before starting a correction.
# teleop.disable_torque()
pass
elif new_phase == DAggerPhase.AUTONOMOUS:
interpolator.reset()