some more iterations

Steven Palma
2026-04-14 16:34:52 +02:00
parent f55782f9f7
commit 49f32b9796
9 changed files with 343 additions and 311 deletions
-1
@@ -222,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
-1
@@ -222,7 +222,6 @@ def main():
# Save episode
dataset.save_episode()
episode_idx += 1
finally:
# Clean up
log_say("Stop recording")
+63 -12
@@ -42,12 +42,32 @@ from lerobot.robots import Robot, make_robot_from_config
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
from .configs import BaseStrategyConfig, RolloutConfig
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
from .robot_wrapper import ThreadSafeRobot
logger = logging.getLogger(__name__)
def _resolve_action_key_order(
    policy_action_names: list[str] | None, dataset_action_names: list[str]
) -> list[str]:
    """Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
    if not policy_action_names:
        return dataset_action_names
    policy_action_names = list(policy_action_names)
    if len(policy_action_names) != len(dataset_action_names):
        logger.warning(
            "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
            len(policy_action_names),
            len(dataset_action_names),
        )
        return dataset_action_names
    if set(dataset_action_names) != set(policy_action_names):
        logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
        return dataset_action_names
    return policy_action_names
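A quick sketch of the fallback behavior, with hypothetical joint names (not from this repo):

# Hypothetical names, for illustration only.
dataset_names = ["shoulder.pos", "elbow.pos", "gripper.pos"]

_resolve_action_key_order(None, dataset_names)
# -> dataset order (no policy names available)
_resolve_action_key_order(["elbow.pos"], dataset_names)
# -> dataset order, with a length-mismatch warning
_resolve_action_key_order(["elbow.pos", "shoulder.pos", "gripper.pos"], dataset_names)
# -> policy order is kept: same names as a set, possibly different order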
@dataclass
class RolloutContext:
"""Bundle of shared resources passed to every rollout strategy.
@@ -69,6 +89,7 @@ class RolloutContext:
shutdown_event: Event = field(default_factory=Event)
dataset_features: dict = field(default_factory=dict)
action_keys: list[str] = field(default_factory=list)
ordered_action_keys: list[str] = field(default_factory=list)
hw_features: dict = field(default_factory=dict)
@@ -92,26 +113,37 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
# --- Policy ---
# Use cfg.policy directly (already loaded in RolloutConfig.__post_init__)
# instead of reloading from disk.
policy_config = cfg.policy
use_rtc = cfg.rtc.enabled
policy_class = get_policy_class(cfg.policy.type)
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
policy_class = get_policy_class(policy_config.type)
# Reload config from pretrained path for full model parameters
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
# Merge any CLI overrides from cfg.policy into full_config
for attr in ("device", "use_amp"):
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
cli_val = getattr(cfg.policy, attr)
if cli_val is not None:
setattr(full_config, attr, cli_val)
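Read in isolation, the override merge above is a small reusable pattern; a hypothetical standalone helper, illustrative only and not part of the commit:

def merge_cli_overrides(full_config, cli_config, attrs=("device", "use_amp")):
    """Copy only explicitly-set CLI values onto the config loaded from disk."""
    for attr in attrs:
        if hasattr(cli_config, attr) and hasattr(full_config, attr):
            cli_val = getattr(cli_config, attr)
            if cli_val is not None:  # None means "not set on the CLI"
                setattr(full_config, attr, cli_val)
    return full_config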
# Set compile_model for pi0/pi05
if hasattr(policy_config, "compile_model"):
policy_config.compile_model = cfg.use_torch_compile
if hasattr(full_config, "compile_model"):
full_config.compile_model = cfg.use_torch_compile
# Handle PEFT models
if policy_config.use_peft:
if full_config.use_peft:
from peft import PeftConfig, PeftModel
peft_path = cfg.policy.pretrained_path
peft_config = PeftConfig.from_pretrained(peft_path)
policy = policy_class.from_pretrained(
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
)
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
else:
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
# Enable RTC on the policy
if use_rtc:
@@ -136,10 +168,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
except Exception as e:
logger.warning("Failed to apply torch.compile: %s", e)
# --- Observation features (filter to .pos joints + camera streams) ---
# --- Observation features ---
# Hardware-level features: camera features are tuples (H, W, C), state
# features are the ``float`` type. This is the canonical pattern used
# throughout the codebase (see feature_utils.py:hw_to_dataset_features).
all_obs_features = robot.observation_features
observation_features_hw = {
k: v for k, v in all_obs_features.items() if k.endswith(".pos") or isinstance(v, tuple)
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
}
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
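For illustration, assuming a robot exposing one camera and two scalar sensors (hypothetical keys), the new filter keeps exactly the (H, W, C) tuples and float-typed entries:

# Hypothetical observation_features, for illustration only.
all_obs_features = {
    "front_cam": (480, 640, 3),   # camera stream: (H, W, C) tuple -> kept
    "shoulder.pos": float,        # state feature: the float type itself -> kept
    "shoulder.raw_ticks": int,    # neither float nor a tuple -> dropped
}
observation_features_hw = {
    k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
}
assert set(observation_features_hw) == {"front_cam", "shoulder.pos"}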
@@ -163,6 +198,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
# Action keys
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
# Ordered action keys (reconcile policy vs dataset ordering)
policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
action_keys,
)
# --- Dataset ---
dataset = None
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
@@ -180,6 +222,14 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
* len(robot.cameras if hasattr(robot, "cameras") else []),
)
else:
# Add intervention column for DAgger strategy
if isinstance(cfg.strategy, DAggerStrategyConfig):
dataset_features["intervention"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}
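With the feature declared, every recorded frame must carry a matching int64 value; the DAgger strategy later in this commit tags frames like this (names as in the dagger.py diff below):

import numpy as np

# Condensed from dagger.py below; obs_frame, action_frame, task_str, dataset
# come from the surrounding strategy code.
frame = {
    **obs_frame,
    **action_frame,
    "task": task_str,
    # 1 for human-correction frames, 0 for autonomous rollout frames
    "intervention": np.array([1], dtype=np.int64),
}
dataset.add_frame(frame)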
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
@@ -206,11 +256,11 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
)
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
policy_cfg=policy_config,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats,
preprocessor_overrides={
"device_processor": {"device": cfg.device or cfg.policy.device},
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
},
)
@@ -230,5 +280,6 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
shutdown_event=shutdown_event,
dataset_features=dataset_features,
action_keys=action_keys,
ordered_action_keys=ordered_action_keys,
hw_features=hw_features,
)
+26 -54
View File
@@ -31,10 +31,10 @@ from typing import Any
import torch
from lerobot.common.control_utils import prepare_observation_for_inference
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker
from lerobot.policies.rtc import ActionQueue, LatencyTracker
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import (
NormalizerProcessorStep,
PolicyProcessorPipeline,
@@ -93,26 +93,6 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded
def _resolve_action_key_order(
    policy_action_names: list[str] | None, dataset_action_names: list[str]
) -> list[str]:
    """Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
    if not policy_action_names:
        return dataset_action_names
    policy_action_names = list(policy_action_names)
    if len(policy_action_names) != len(dataset_action_names):
        logger.warning(
            "policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
            len(policy_action_names),
            len(dataset_action_names),
        )
        return dataset_action_names
    if set(dataset_action_names) != set(policy_action_names):
        logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
        return dataset_action_names
    return policy_action_names
# ---------------------------------------------------------------------------
# InferenceEngine
# ---------------------------------------------------------------------------
@@ -143,12 +123,13 @@ class InferenceEngine:
Control loop frequency.
device:
Torch device string.
interpolator:
Action interpolator (used only in RTC mode for the actor loop).
use_torch_compile:
Whether torch.compile warmup is needed.
compile_warmup_inferences:
Number of warmup inferences before live rollout.
rtc_queue_threshold:
Maximum RTC action queue size before the background thread
pauses generation. Prevents unbounded queue growth.
"""
def __init__(
@@ -163,9 +144,9 @@ class InferenceEngine:
task: str,
fps: float,
device: str | None,
interpolator: ActionInterpolator | None = None,
use_torch_compile: bool = False,
compile_warmup_inferences: int = 2,
rtc_queue_threshold: int = 30,
) -> None:
self._policy = policy
self._preprocessor = preprocessor
@@ -177,9 +158,9 @@ class InferenceEngine:
self._task = task
self._fps = fps
self._device = device or "cpu"
self._interpolator = interpolator
self._use_torch_compile = use_torch_compile
self._compile_warmup_inferences = compile_warmup_inferences
self._rtc_queue_threshold = rtc_queue_threshold
# RTC state
self._use_rtc = rtc_config.enabled
@@ -270,8 +251,6 @@ class InferenceEngine:
self._postprocessor.reset()
if self._use_rtc:
self._action_queue = ActionQueue(self._rtc_config)
if self._interpolator is not None:
self._interpolator.reset()
# ------------------------------------------------------------------
# Sync inference
@@ -329,8 +308,7 @@ class InferenceEngine:
try:
latency_tracker = LatencyTracker()
time_per_chunk = 1.0 / self._fps
threshold = 30
policy_device = self._policy.config.device
policy_device = torch.device(self._device)
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
inference_count = 0
@@ -347,7 +325,7 @@ class InferenceEngine:
time.sleep(0.01)
continue
if queue.qsize() <= threshold:
if queue.qsize() <= self._rtc_queue_threshold:
try:
current_time = time.perf_counter()
idx_before = queue.get_action_index()
@@ -356,35 +334,29 @@ class InferenceEngine:
latency = latency_tracker.max()
delay = math.ceil(latency / time_per_chunk) if latency else 0
# Build observation batch
# Build observation batch using the same pipeline as sync inference
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
for name in obs_batch:
obs_batch[name] = torch.from_numpy(obs_batch[name])
if "image" in name:
obs_batch[name] = obs_batch[name].float() / 255
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device)
obs_batch = prepare_observation_for_inference(
obs_batch, policy_device, self._task, self._robot.robot_type
)
# predict_action_chunk expects batched task format
obs_batch["task"] = [self._task]
obs_batch["robot_type"] = self._obs_holder.get("robot_type", "unknown")
preprocessed = self._preprocessor(obs_batch)
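For reference, the removed lines above amount to the following inline preparation, which prepare_observation_for_inference now centralizes (reassembled from the deleted code, with indentation and comments restored):

obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
for name in obs_batch:
    obs_batch[name] = torch.from_numpy(obs_batch[name])
    if "image" in name:
        obs_batch[name] = obs_batch[name].float() / 255                   # uint8 -> [0, 1]
        obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()   # HWC -> CHW
    obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device)      # add batch dim
obs_batch["task"] = [self._task]  # predict_action_chunk expects batched task format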
# Re-anchor leftover for relative-action policies
if (
prev_actions is not None
and self._relative_step is not None
and OBS_STATE in obs_batch
):
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_abs,
current_state=obs_batch[OBS_STATE],
relative_step=self._relative_step,
normalizer_step=self._normalizer_step,
policy_device=policy_device,
)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
prev_actions = _reanchor_relative_rtc_prefix(
prev_actions_absolute=prev_abs,
current_state=state_tensor,
relative_step=self._relative_step,
normalizer_step=self._normalizer_step,
policy_device=policy_device,
)
if prev_actions is not None:
prev_actions = _normalize_prev_actions_length(
+80 -1
@@ -12,16 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rollout strategy ABC and factory."""
"""Rollout strategy ABC, factory, and shared inference helper."""
from __future__ import annotations
import abc
from typing import TYPE_CHECKING
import torch
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
if TYPE_CHECKING:
from lerobot.rollout.configs import RolloutStrategyConfig
from lerobot.rollout.context import RolloutContext
from lerobot.rollout.inference import InferenceEngine
class RolloutStrategy(abc.ABC):
@@ -48,6 +56,77 @@ class RolloutStrategy(abc.ABC):
"""Cleanup: save dataset, stop threads, disconnect hardware."""
# ---------------------------------------------------------------------------
# Shared inference helper
# ---------------------------------------------------------------------------
def infer_action(
    engine: InferenceEngine,
    obs_processed: dict,
    obs_raw: dict,
    ctx: RolloutContext,
    interpolator: ActionInterpolator,
    ordered_keys: list[str],
    features: dict,
) -> dict | None:
    """Run one policy inference step and send the resulting action to the robot.

    Handles both sync and RTC backends. Uses the interpolator for smooth
    control at higher-than-inference rates (works with any multiplier,
    including 1 where it acts as a pass-through).

    Parameters
    ----------
    engine:
        The inference engine (sync or RTC).
    obs_processed:
        Observation dict after ``robot_observation_processor``.
    obs_raw:
        Raw observation dict (needed by ``robot_action_processor``).
    ctx:
        Rollout context.
    interpolator:
        Action interpolator for Nx control rate.
    ordered_keys:
        Ordered action feature names (policy-to-robot mapping).
    features:
        Feature specification dict for ``build_dataset_frame`` /
        ``make_robot_action``. Use ``dataset.features`` when recording,
        ``ctx.dataset_features`` otherwise.

    Returns
    -------
    Action dict sent to the robot, or ``None`` if no action was
    available (empty RTC queue, interpolator buffer not ready).
    """
    if engine.is_rtc:
        if interpolator.needs_new_action():
            action_tensor = engine.consume_rtc_action()
            if action_tensor is not None:
                interpolator.add(action_tensor.cpu())
    else:
        if interpolator.needs_new_action():
            obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
            action_tensor = engine.get_action_sync(obs_frame)
            action_dict = make_robot_action(action_tensor, features)
            action_t = torch.tensor([action_dict[k] for k in ordered_keys])
            interpolator.add(action_t)

    interp = interpolator.get()
    if interp is not None:
        action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
        processed = ctx.robot_action_processor((action_dict, obs_raw))
        ctx.robot_wrapper.send_action(processed)
        return action_dict
    return None
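A strategy's control loop then reduces to one call per tick; condensed from BaseStrategy.run() in this commit (warmup and RTC resume handling omitted):

while not ctx.shutdown_event.is_set():
    loop_start = time.perf_counter()
    obs = ctx.robot_wrapper.get_observation()
    obs_processed = ctx.robot_observation_processor(obs)
    if engine.is_rtc:
        engine.update_observation(obs_processed)
    infer_action(engine, obs_processed, obs, ctx, interpolator,
                 ctx.ordered_action_keys, ctx.dataset_features)
    if (sleep_t := control_interval - (time.perf_counter() - loop_start)) > 0:
        precise_sleep(sleep_t)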
# ---------------------------------------------------------------------------
# Strategy factory
# ---------------------------------------------------------------------------
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
"""Instantiate the appropriate strategy from a config object."""
from lerobot.rollout.configs import (
+19 -41
@@ -20,14 +20,11 @@ import logging
import time
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..context import RolloutContext
from ..inference import InferenceEngine, _resolve_action_key_order
from . import RolloutStrategy
from ..inference import InferenceEngine
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -43,9 +40,10 @@ class BaseStrategy(RolloutStrategy):
def __init__(self, config):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
def setup(self, ctx: RolloutContext) -> None:
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
@@ -58,7 +56,6 @@ class BaseStrategy(RolloutStrategy):
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
interpolator=interpolator,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
@@ -69,16 +66,10 @@ class BaseStrategy(RolloutStrategy):
engine = self._engine
cfg = ctx.cfg
robot = ctx.robot_wrapper
action_keys = ctx.action_keys
interpolator = self._interpolator
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(cfg.fps)
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
ordered_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
action_keys,
)
ordered_keys = ctx.ordered_action_keys
start_time = time.perf_counter()
warmup_flushed = False
@@ -98,34 +89,21 @@ class BaseStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
# Wait for torch.compile warmup before running live inference
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
if interpolator.needs_new_action():
action_tensor = engine.consume_rtc_action()
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
interp = interpolator.get()
if interp is not None:
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
else:
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action_sync(obs_frame)
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features)
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
+80 -91
@@ -29,26 +29,26 @@ Keyboard Controls:
from __future__ import annotations
import contextlib
import logging
import time
from typing import Any
import torch
import numpy as np
from lerobot.common.control_utils import is_headless, predict_action
from lerobot.common.control_utils import is_headless
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.processor import RobotProcessorPipeline
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say
from ..configs import DAggerStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine, _resolve_action_key_order
from . import RolloutStrategy
from ..inference import InferenceEngine
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -94,8 +94,19 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp
time.sleep(1 / fps)
def _reset_loop(robot, teleop, events: dict, fps: int) -> None:
"""Reset period where the human repositions the environment."""
def _reset_loop(
robot,
teleop,
events: dict,
fps: int,
teleop_action_processor: RobotProcessorPipeline,
robot_action_processor: RobotProcessorPipeline,
) -> None:
"""Reset period where the human repositions the environment.
All teleop actions flow through the processor pipelines to ensure
correct behavior for EE-space robots.
"""
logger.info("RESET — press any key to enable teleoperation")
events["in_reset"] = True
@@ -117,8 +128,11 @@ def _reset_loop(robot, teleop, events: dict, fps: int) -> None:
while not events["start_next_episode"] and not events["stop_recording"]:
loop_start = time.perf_counter()
obs = robot.get_observation()
action = teleop.get_action()
robot.send_action(action)
processed_teleop = teleop_action_processor((action, obs))
robot_action_to_send = robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
events["in_reset"] = False
@@ -251,6 +265,10 @@ class DAggerStrategy(RolloutStrategy):
Supports both synchronous and RTC inference backends.
All actions (policy and teleop) flow through the appropriate
processor pipelines, supporting EE-space recording.
Intervention frames are tagged with ``intervention=1`` (int64) in
the dataset to allow downstream BC training to distinguish
autonomous from human-corrected data.
"""
config: DAggerStrategyConfig
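Downstream, the flag makes splitting corrected from autonomous frames a one-liner; a hypothetical sketch (the exact accessor for the frame table may differ from hf_dataset):

# Hypothetical downstream filtering; accessor names are assumptions.
ds = LeRobotDataset("user/dagger_run")  # repo_id is illustrative
corrected = ds.hf_dataset.filter(lambda row: row["intervention"][0] == 1)
print(f"{len(corrected)} human-corrected frames out of {len(ds.hf_dataset)}")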
@@ -258,11 +276,12 @@ class DAggerStrategy(RolloutStrategy):
def __init__(self, config: DAggerStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._listener = None
self._events: dict[str, Any] = {}
def setup(self, ctx: RolloutContext) -> None:
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
@@ -275,7 +294,6 @@ class DAggerStrategy(RolloutStrategy):
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
interpolator=interpolator,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
@@ -293,7 +311,6 @@ class DAggerStrategy(RolloutStrategy):
logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop")
def run(self, ctx: RolloutContext) -> None:
engine = self._engine
dataset = ctx.dataset
events = self._events
teleop = ctx.teleop
@@ -317,13 +334,18 @@ class DAggerStrategy(RolloutStrategy):
recorded += 1
if recorded < self.config.num_episodes and not events["stop_recording"]:
_reset_loop(ctx.robot_wrapper, teleop, events, int(ctx.cfg.fps))
_reset_loop(
ctx.robot_wrapper,
teleop,
events,
int(ctx.cfg.fps),
ctx.teleop_action_processor,
ctx.robot_action_processor,
)
finally:
try:
with contextlib.suppress(Exception):
dataset.save_episode()
except Exception:
pass
def teardown(self, ctx: RolloutContext) -> None:
log_say("Stop recording", self.config.play_sounds, blocking=True)
@@ -360,27 +382,22 @@ class DAggerStrategy(RolloutStrategy):
teleop = ctx.teleop
dataset = ctx.dataset
events = self._events
interpolator = self._interpolator
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(cfg.fps)
stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False
record_stride = max(1, cfg.interpolation_multiplier)
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
ordered_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
ctx.action_keys,
)
dataset_action_keys = list(dataset.features.get(ACTION, {}).get("names", ctx.action_keys))
ordered_keys = ctx.ordered_action_keys
features = dataset.features
engine.reset()
interpolator.reset()
_teleop_disable_torque(teleop)
was_paused = False
waiting_for_takeover = False
last_action: dict[str, Any] | None = None
robot_action: dict[str, Any] = {}
frame_buffer: list[dict] = []
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
@@ -444,7 +461,7 @@ class DAggerStrategy(RolloutStrategy):
# --- Get observation ---
obs = robot.get_observation()
obs_processed = ctx.robot_observation_processor(obs)
obs_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
# --- CORRECTION: human teleop control ---
if events["correction_active"]:
@@ -452,9 +469,14 @@ class DAggerStrategy(RolloutStrategy):
processed_teleop = ctx.teleop_action_processor((teleop_action, obs))
robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs))
robot.send_action(robot_action_to_send)
action_frame = build_dataset_frame(dataset.features, processed_teleop, prefix=ACTION)
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {**obs_frame, **action_frame, "task": task_str}
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([1], dtype=np.int64),
}
if stream_online:
dataset.add_frame(frame)
else:
@@ -471,73 +493,40 @@ class DAggerStrategy(RolloutStrategy):
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
timestamp = time.perf_counter() - start_t
continue
# Wait for torch.compile warmup
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
timestamp = time.perf_counter() - start_t
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
if interpolator.needs_new_action():
action_tensor = engine.consume_rtc_action()
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
interp = interpolator.get()
if interp is not None:
robot_action = {
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
if action_dict is not None:
last_action = ctx.robot_action_processor((action_dict, obs))
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {
**obs_frame,
**action_frame,
"task": task_str,
"intervention": np.array([0], dtype=np.int64),
}
processed = ctx.robot_action_processor((robot_action, obs))
robot.send_action(processed)
last_action = processed
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {**obs_frame, **action_frame, "task": task_str}
if stream_online:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
record_tick += 1
else:
# Sync inference
if interpolator.needs_new_action():
device = get_safe_torch_device(cfg.device)
action_tensor = predict_action(
observation=obs_frame,
policy=ctx.policy,
device=device,
preprocessor=ctx.preprocessor,
postprocessor=ctx.postprocessor,
use_amp=ctx.policy.config.use_amp,
task=task_str,
robot_type=robot.robot_type,
)
robot_action = make_robot_action(action_tensor, dataset.features)
action_t = torch.tensor([robot_action[k] for k in dataset_action_keys])
interpolator.add(action_t)
interp = interpolator.get()
if interp is not None:
robot_action = {k: interp[i].item() for i, k in enumerate(dataset_action_keys)}
processed = ctx.robot_action_processor((robot_action, obs))
robot.send_action(processed)
last_action = processed
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
if record_tick % record_stride == 0:
frame = {**obs_frame, **action_frame, "task": task_str}
if stream_online:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
record_tick += 1
if stream_online:
dataset.add_frame(frame)
else:
frame_buffer.append(frame)
record_tick += 1
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
+40 -60
@@ -16,21 +16,22 @@
from __future__ import annotations
import contextlib
import logging
import time
from threading import Event as ThreadingEvent
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..configs import HighlightStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine, _resolve_action_key_order
from ..inference import InferenceEngine
from ..ring_buffer import RolloutRingBuffer
from . import RolloutStrategy
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -54,13 +55,14 @@ class HighlightStrategy(RolloutStrategy):
def __init__(self, config: HighlightStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._ring: RolloutRingBuffer | None = None
self._listener = None
self._save_requested = False
self._recording_live = False
self._save_requested = ThreadingEvent()
self._recording_live = ThreadingEvent()
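Replacing the plain bools with threading.Event makes the listener-thread-to-control-loop handoff explicit; the pattern in miniature:

from threading import Event

save_requested = Event()

def on_press(key) -> None:        # runs on the pynput listener thread
    save_requested.set()

def control_tick() -> bool:       # runs on the control-loop thread
    if save_requested.is_set():
        save_requested.clear()    # consume the one-shot request
        return True
    return False

A bare bool read/write is usually safe under the GIL, but Event documents the cross-thread intent and composes with wait().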
def setup(self, ctx: RolloutContext) -> None:
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
@@ -73,7 +75,6 @@ class HighlightStrategy(RolloutStrategy):
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
interpolator=interpolator,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
@@ -97,17 +98,12 @@ class HighlightStrategy(RolloutStrategy):
cfg = ctx.cfg
robot = ctx.robot_wrapper
dataset = ctx.dataset
action_keys = ctx.action_keys
ring = self._ring
interpolator = self._interpolator
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(cfg.fps)
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
ordered_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
action_keys,
)
ordered_keys = ctx.ordered_action_keys
features = dataset.features
if engine.is_rtc:
engine.resume()
@@ -126,70 +122,58 @@ class HighlightStrategy(RolloutStrategy):
obs = robot.get_observation()
obs_processed = ctx.robot_observation_processor(obs)
action_dict = None
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
if interpolator.needs_new_action():
action_tensor = engine.consume_rtc_action()
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
interp = interpolator.get()
if interp is not None:
action_dict = {
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
}
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
else:
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action_sync(obs_frame)
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
# Build frame for ring buffer / live recording
if action_dict is not None:
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
# Handle save key toggle
if self._save_requested:
self._save_requested = False
if not self._recording_live:
if self._save_requested.is_set():
self._save_requested.clear()
if not self._recording_live.is_set():
logger.info(
"Flushing ring buffer (%d frames) + starting live recording", len(ring)
)
for buffered_frame in ring.drain():
dataset.add_frame(buffered_frame)
self._recording_live = True
self._recording_live.set()
else:
# Save current frame as the last frame of the episode
dataset.add_frame(frame)
dataset.save_episode()
logger.info("Episode saved")
self._recording_live = False
self._recording_live.clear()
engine.reset()
interpolator.reset()
if engine.is_rtc:
engine.resume()
if self._recording_live:
if self._recording_live.is_set():
dataset.add_frame(frame)
else:
# Current frame goes into the ring buffer for next potential save.
ring.append(frame)
dt = time.perf_counter() - loop_start
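The save key is thus a two-phase toggle: the first press drains the pre-roll ring into the episode and goes live, the second press closes the episode. A deque-based sketch of the drain step (the project's RolloutRingBuffer is assumed to expose something similar):

from collections import deque

ring: deque[dict] = deque(maxlen=300)  # pre-roll capacity is illustrative

def drain(buf: deque[dict]):
    """Yield and remove buffered frames, oldest first."""
    while buf:
        yield buf.popleft()

# First press: flush the pre-roll, then keep recording live.
# for buffered_frame in drain(ring):
#     dataset.add_frame(buffered_frame)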
@@ -197,11 +181,9 @@ class HighlightStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
if self._recording_live:
try:
if self._recording_live.is_set():
with contextlib.suppress(Exception):
dataset.save_episode()
except Exception:
pass
def teardown(self, ctx: RolloutContext) -> None:
if self._engine is not None:
@@ -237,13 +219,11 @@ class HighlightStrategy(RolloutStrategy):
save_key = self.config.save_key
def on_press(key):
try:
with contextlib.suppress(Exception):
if hasattr(key, "char") and key.char == save_key:
self._save_requested = True
self._save_requested.set()
elif key == keyboard.Key.esc:
self._save_requested = False
except Exception:
pass
self._save_requested.clear()
self._listener = keyboard.Listener(on_press=on_press)
self._listener.start()
+35 -50
@@ -16,21 +16,21 @@
from __future__ import annotations
import contextlib
import logging
import time
from threading import Thread
from lerobot.datasets import VideoEncodingManager
from lerobot.policies.rtc import ActionInterpolator
from lerobot.policies.utils import make_robot_action
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.feature_utils import build_dataset_frame
from lerobot.utils.robot_utils import precise_sleep
from ..configs import SentryStrategyConfig
from ..context import RolloutContext
from ..inference import InferenceEngine, _resolve_action_key_order
from . import RolloutStrategy
from ..inference import InferenceEngine
from . import RolloutStrategy, infer_action
logger = logging.getLogger(__name__)
@@ -55,10 +55,12 @@ class SentryStrategy(RolloutStrategy):
def __init__(self, config: SentryStrategyConfig):
super().__init__(config)
self._engine: InferenceEngine | None = None
self._interpolator: ActionInterpolator | None = None
self._push_thread: Thread | None = None
self._needs_push: bool = False
def setup(self, ctx: RolloutContext) -> None:
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
self._engine = InferenceEngine(
policy=ctx.policy,
@@ -71,7 +73,6 @@ class SentryStrategy(RolloutStrategy):
task=ctx.cfg.task,
fps=ctx.cfg.fps,
device=ctx.cfg.device,
interpolator=interpolator,
use_torch_compile=ctx.cfg.use_torch_compile,
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
)
@@ -87,16 +88,11 @@ class SentryStrategy(RolloutStrategy):
cfg = ctx.cfg
robot = ctx.robot_wrapper
dataset = ctx.dataset
action_keys = ctx.action_keys
interpolator = self._interpolator
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
control_interval = interpolator.get_control_interval(cfg.fps)
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
ordered_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
action_keys,
)
ordered_keys = ctx.ordered_action_keys
features = dataset.features
if engine.is_rtc:
engine.resume()
@@ -117,45 +113,31 @@ class SentryStrategy(RolloutStrategy):
obs = robot.get_observation()
obs_processed = ctx.robot_observation_processor(obs)
action_dict = None
if engine.is_rtc:
engine.update_observation(obs_processed)
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
continue
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if cfg.use_torch_compile and not warmup_flushed:
engine.reset()
interpolator.reset()
warmup_flushed = True
if engine.is_rtc:
engine.resume()
if interpolator.needs_new_action():
action_tensor = engine.consume_rtc_action()
if action_tensor is not None:
interpolator.add(action_tensor.cpu())
interp = interpolator.get()
if interp is not None:
action_dict = {
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
}
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
else:
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
action_tensor = engine.get_action_sync(obs_frame)
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
processed = ctx.robot_action_processor((action_dict, obs))
robot.send_action(processed)
action_dict = infer_action(
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
)
# Record frame
if action_dict is not None:
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION)
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
frame = {**obs_frame, **action_frame, "task": task_str}
dataset.add_frame(frame)
@@ -164,6 +146,7 @@ class SentryStrategy(RolloutStrategy):
if elapsed >= self.config.episode_duration_s:
dataset.save_episode()
episodes_since_push += 1
self._needs_push = True
logger.info("Episode saved (total: %d)", dataset.num_episodes)
if episodes_since_push >= self.config.upload_every_n_episodes:
@@ -181,26 +164,27 @@ class SentryStrategy(RolloutStrategy):
precise_sleep(sleep_t)
finally:
try:
with contextlib.suppress(Exception):
dataset.save_episode()
except Exception:
pass
self._needs_push = True
def teardown(self, ctx: RolloutContext) -> None:
if self._engine is not None:
self._engine.stop()
# Wait for any in-flight background push
if self._push_thread is not None and self._push_thread.is_alive():
self._push_thread.join(timeout=60)
if ctx.dataset is not None:
ctx.dataset.finalize()
if ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
# Only push if there are unsaved changes since last background push
if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
ctx.dataset.push_to_hub(
tags=ctx.cfg.dataset.tags,
private=ctx.cfg.dataset.private,
)
if self._push_thread is not None and self._push_thread.is_alive():
self._push_thread.join(timeout=60)
if ctx.robot.is_connected:
ctx.robot.disconnect()
if ctx.teleop is not None and ctx.teleop.is_connected:
@@ -219,6 +203,7 @@ class SentryStrategy(RolloutStrategy):
tags=cfg.dataset.tags if cfg.dataset else None,
private=cfg.dataset.private if cfg.dataset else False,
)
self._needs_push = False
logger.info("Background push to hub complete")
except Exception as e:
logger.error("Background push failed: %s", e)
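Taken together, the dirty flag and the joined thread implement "push at least once, and only when something changed"; the shape of the pattern, simplified, with a hypothetical push_fn:

import threading

class BackgroundPusher:
    def __init__(self) -> None:
        self._needs_push = False
        self._thread: threading.Thread | None = None

    def mark_dirty(self) -> None:            # after each saved episode
        self._needs_push = True

    def push_async(self, push_fn) -> None:
        def _run() -> None:
            push_fn()
            self._needs_push = False          # cleared only on success
        self._thread = threading.Thread(target=_run, daemon=True)
        self._thread.start()

    def finalize(self, push_fn) -> None:      # teardown path
        if self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=60)     # wait for any in-flight push
        if self._needs_push:
            push_fn()                          # final catch-up push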