mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
some more iterations
This commit is contained in:
@@ -222,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -222,7 +222,6 @@ def main():
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
finally:
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
|
||||
@@ -42,12 +42,32 @@ from lerobot.robots import Robot, make_robot_from_config
|
||||
from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config
|
||||
from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features
|
||||
|
||||
from .configs import BaseStrategyConfig, RolloutConfig
|
||||
from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig
|
||||
from .robot_wrapper import ThreadSafeRobot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_action_key_order(
|
||||
policy_action_names: list[str] | None, dataset_action_names: list[str]
|
||||
) -> list[str]:
|
||||
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
|
||||
if not policy_action_names:
|
||||
return dataset_action_names
|
||||
policy_action_names = list(policy_action_names)
|
||||
if len(policy_action_names) != len(dataset_action_names):
|
||||
logger.warning(
|
||||
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
|
||||
len(policy_action_names),
|
||||
len(dataset_action_names),
|
||||
)
|
||||
return dataset_action_names
|
||||
if set(dataset_action_names) != set(policy_action_names):
|
||||
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
|
||||
return dataset_action_names
|
||||
return policy_action_names
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutContext:
|
||||
"""Bundle of shared resources passed to every rollout strategy.
|
||||
@@ -69,6 +89,7 @@ class RolloutContext:
|
||||
shutdown_event: Event = field(default_factory=Event)
|
||||
dataset_features: dict = field(default_factory=dict)
|
||||
action_keys: list[str] = field(default_factory=list)
|
||||
ordered_action_keys: list[str] = field(default_factory=list)
|
||||
hw_features: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -92,26 +113,37 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# --- Policy ---
|
||||
# Use cfg.policy directly (already loaded in RolloutConfig.__post_init__)
|
||||
# instead of reloading from disk.
|
||||
policy_config = cfg.policy
|
||||
use_rtc = cfg.rtc.enabled
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
policy_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
policy_class = get_policy_class(policy_config.type)
|
||||
|
||||
# Reload config from pretrained path for full model parameters
|
||||
full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
# Merge any CLI overrides from cfg.policy into full_config
|
||||
for attr in ("device", "use_amp"):
|
||||
if hasattr(cfg.policy, attr) and hasattr(full_config, attr):
|
||||
cli_val = getattr(cfg.policy, attr)
|
||||
if cli_val is not None:
|
||||
setattr(full_config, attr, cli_val)
|
||||
|
||||
# Set compile_model for pi0/pi05
|
||||
if hasattr(policy_config, "compile_model"):
|
||||
policy_config.compile_model = cfg.use_torch_compile
|
||||
if hasattr(full_config, "compile_model"):
|
||||
full_config.compile_model = cfg.use_torch_compile
|
||||
|
||||
# Handle PEFT models
|
||||
if policy_config.use_peft:
|
||||
if full_config.use_peft:
|
||||
from peft import PeftConfig, PeftModel
|
||||
|
||||
peft_path = cfg.policy.pretrained_path
|
||||
peft_config = PeftConfig.from_pretrained(peft_path)
|
||||
policy = policy_class.from_pretrained(
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=policy_config
|
||||
pretrained_name_or_path=peft_config.base_model_name_or_path, config=full_config
|
||||
)
|
||||
policy = PeftModel.from_pretrained(policy, peft_path, config=peft_config)
|
||||
else:
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=policy_config)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config)
|
||||
|
||||
# Enable RTC on the policy
|
||||
if use_rtc:
|
||||
@@ -136,10 +168,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
except Exception as e:
|
||||
logger.warning("Failed to apply torch.compile: %s", e)
|
||||
|
||||
# --- Observation features (filter to .pos joints + camera streams) ---
|
||||
# --- Observation features ---
|
||||
# Hardware-level features: camera features are tuples (H, W, C), state
|
||||
# features are the ``float`` type. This is the canonical pattern used
|
||||
# throughout the codebase (see feature_utils.py:hw_to_dataset_features).
|
||||
all_obs_features = robot.observation_features
|
||||
observation_features_hw = {
|
||||
k: v for k, v in all_obs_features.items() if k.endswith(".pos") or isinstance(v, tuple)
|
||||
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
|
||||
}
|
||||
|
||||
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
|
||||
@@ -163,6 +198,13 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
# Action keys
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
|
||||
# Ordered action keys (reconcile policy vs dataset ordering)
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
ordered_action_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
action_keys,
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset = None
|
||||
if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig):
|
||||
@@ -180,6 +222,14 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
* len(robot.cameras if hasattr(robot, "cameras") else []),
|
||||
)
|
||||
else:
|
||||
# Add intervention column for DAgger strategy
|
||||
if isinstance(cfg.strategy, DAggerStrategyConfig):
|
||||
dataset_features["intervention"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.dataset.fps,
|
||||
@@ -206,11 +256,11 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
policy_cfg=policy_config,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device or cfg.policy.device},
|
||||
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
||||
},
|
||||
)
|
||||
@@ -230,5 +280,6 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC
|
||||
shutdown_event=shutdown_event,
|
||||
dataset_features=dataset_features,
|
||||
action_keys=action_keys,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
hw_features=hw_features,
|
||||
)
|
||||
|
||||
@@ -31,10 +31,10 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.control_utils import prepare_observation_for_inference
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc import ActionQueue, LatencyTracker
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import (
|
||||
NormalizerProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
@@ -93,26 +93,6 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
|
||||
return padded
|
||||
|
||||
|
||||
def _resolve_action_key_order(
|
||||
policy_action_names: list[str] | None, dataset_action_names: list[str]
|
||||
) -> list[str]:
|
||||
"""Choose action name ordering for mapping policy tensor outputs to robot action dicts."""
|
||||
if not policy_action_names:
|
||||
return dataset_action_names
|
||||
policy_action_names = list(policy_action_names)
|
||||
if len(policy_action_names) != len(dataset_action_names):
|
||||
logger.warning(
|
||||
"policy.action_feature_names length (%d) != dataset action dim (%d); using dataset order",
|
||||
len(policy_action_names),
|
||||
len(dataset_action_names),
|
||||
)
|
||||
return dataset_action_names
|
||||
if set(dataset_action_names) != set(policy_action_names):
|
||||
logger.warning("policy.action_feature_names keys don't match dataset; using dataset order")
|
||||
return dataset_action_names
|
||||
return policy_action_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# InferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -143,12 +123,13 @@ class InferenceEngine:
|
||||
Control loop frequency.
|
||||
device:
|
||||
Torch device string.
|
||||
interpolator:
|
||||
Action interpolator (used only in RTC mode for the actor loop).
|
||||
use_torch_compile:
|
||||
Whether torch.compile warmup is needed.
|
||||
compile_warmup_inferences:
|
||||
Number of warmup inferences before live rollout.
|
||||
rtc_queue_threshold:
|
||||
Maximum RTC action queue size before the background thread
|
||||
pauses generation. Prevents unbounded queue growth.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -163,9 +144,9 @@ class InferenceEngine:
|
||||
task: str,
|
||||
fps: float,
|
||||
device: str | None,
|
||||
interpolator: ActionInterpolator | None = None,
|
||||
use_torch_compile: bool = False,
|
||||
compile_warmup_inferences: int = 2,
|
||||
rtc_queue_threshold: int = 30,
|
||||
) -> None:
|
||||
self._policy = policy
|
||||
self._preprocessor = preprocessor
|
||||
@@ -177,9 +158,9 @@ class InferenceEngine:
|
||||
self._task = task
|
||||
self._fps = fps
|
||||
self._device = device or "cpu"
|
||||
self._interpolator = interpolator
|
||||
self._use_torch_compile = use_torch_compile
|
||||
self._compile_warmup_inferences = compile_warmup_inferences
|
||||
self._rtc_queue_threshold = rtc_queue_threshold
|
||||
|
||||
# RTC state
|
||||
self._use_rtc = rtc_config.enabled
|
||||
@@ -270,8 +251,6 @@ class InferenceEngine:
|
||||
self._postprocessor.reset()
|
||||
if self._use_rtc:
|
||||
self._action_queue = ActionQueue(self._rtc_config)
|
||||
if self._interpolator is not None:
|
||||
self._interpolator.reset()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sync inference
|
||||
@@ -329,8 +308,7 @@ class InferenceEngine:
|
||||
try:
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / self._fps
|
||||
threshold = 30
|
||||
policy_device = self._policy.config.device
|
||||
policy_device = torch.device(self._device)
|
||||
|
||||
warmup_required = max(1, self._compile_warmup_inferences) if self._use_torch_compile else 0
|
||||
inference_count = 0
|
||||
@@ -347,7 +325,7 @@ class InferenceEngine:
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if queue.qsize() <= threshold:
|
||||
if queue.qsize() <= self._rtc_queue_threshold:
|
||||
try:
|
||||
current_time = time.perf_counter()
|
||||
idx_before = queue.get_action_index()
|
||||
@@ -356,35 +334,29 @@ class InferenceEngine:
|
||||
latency = latency_tracker.max()
|
||||
delay = math.ceil(latency / time_per_chunk) if latency else 0
|
||||
|
||||
# Build observation batch
|
||||
# Build observation batch using the same pipeline as sync inference
|
||||
obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation")
|
||||
for name in obs_batch:
|
||||
obs_batch[name] = torch.from_numpy(obs_batch[name])
|
||||
if "image" in name:
|
||||
obs_batch[name] = obs_batch[name].float() / 255
|
||||
obs_batch[name] = obs_batch[name].permute(2, 0, 1).contiguous()
|
||||
obs_batch[name] = obs_batch[name].unsqueeze(0).to(policy_device)
|
||||
|
||||
obs_batch = prepare_observation_for_inference(
|
||||
obs_batch, policy_device, self._task, self._robot.robot_type
|
||||
)
|
||||
# predict_action_chunk expects batched task format
|
||||
obs_batch["task"] = [self._task]
|
||||
obs_batch["robot_type"] = self._obs_holder.get("robot_type", "unknown")
|
||||
|
||||
preprocessed = self._preprocessor(obs_batch)
|
||||
|
||||
# Re-anchor leftover for relative-action policies
|
||||
if (
|
||||
prev_actions is not None
|
||||
and self._relative_step is not None
|
||||
and OBS_STATE in obs_batch
|
||||
):
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_abs,
|
||||
current_state=obs_batch[OBS_STATE],
|
||||
relative_step=self._relative_step,
|
||||
normalizer_step=self._normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
if prev_actions is not None and self._relative_step is not None:
|
||||
state_tensor = preprocessed.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
prev_actions = _reanchor_relative_rtc_prefix(
|
||||
prev_actions_absolute=prev_abs,
|
||||
current_state=state_tensor,
|
||||
relative_step=self._relative_step,
|
||||
normalizer_step=self._normalizer_step,
|
||||
policy_device=policy_device,
|
||||
)
|
||||
|
||||
if prev_actions is not None:
|
||||
prev_actions = _normalize_prev_actions_length(
|
||||
|
||||
@@ -12,16 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Rollout strategy ABC and factory."""
|
||||
"""Rollout strategy ABC, factory, and shared inference helper."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rollout.configs import RolloutStrategyConfig
|
||||
from lerobot.rollout.context import RolloutContext
|
||||
from lerobot.rollout.inference import InferenceEngine
|
||||
|
||||
|
||||
class RolloutStrategy(abc.ABC):
|
||||
@@ -48,6 +56,77 @@ class RolloutStrategy(abc.ABC):
|
||||
"""Cleanup: save dataset, stop threads, disconnect hardware."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared inference helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_action(
|
||||
engine: InferenceEngine,
|
||||
obs_processed: dict,
|
||||
obs_raw: dict,
|
||||
ctx: RolloutContext,
|
||||
interpolator: ActionInterpolator,
|
||||
ordered_keys: list[str],
|
||||
features: dict,
|
||||
) -> dict | None:
|
||||
"""Run one policy inference step and send the resulting action to the robot.
|
||||
|
||||
Handles both sync and RTC backends. Uses the interpolator for smooth
|
||||
control at higher-than-inference rates (works with any multiplier,
|
||||
including 1 where it acts as a pass-through).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine:
|
||||
The inference engine (sync or RTC).
|
||||
obs_processed:
|
||||
Observation dict after ``robot_observation_processor``.
|
||||
obs_raw:
|
||||
Raw observation dict (needed by ``robot_action_processor``).
|
||||
ctx:
|
||||
Rollout context.
|
||||
interpolator:
|
||||
Action interpolator for Nx control rate.
|
||||
ordered_keys:
|
||||
Ordered action feature names (policy-to-robot mapping).
|
||||
features:
|
||||
Feature specification dict for ``build_dataset_frame`` /
|
||||
``make_robot_action``. Use ``dataset.features`` when recording,
|
||||
``ctx.dataset_features`` otherwise.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Action dict sent to the robot, or ``None`` if no action was
|
||||
available (empty RTC queue, interpolator buffer not ready).
|
||||
"""
|
||||
if engine.is_rtc:
|
||||
if interpolator.needs_new_action():
|
||||
action_tensor = engine.consume_rtc_action()
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
else:
|
||||
if interpolator.needs_new_action():
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action_sync(obs_frame)
|
||||
action_dict = make_robot_action(action_tensor, features)
|
||||
action_t = torch.tensor([action_dict[k] for k in ordered_keys])
|
||||
interpolator.add(action_t)
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
|
||||
processed = ctx.robot_action_processor((action_dict, obs_raw))
|
||||
ctx.robot_wrapper.send_action(processed)
|
||||
return action_dict
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategy factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||
"""Instantiate the appropriate strategy from a config object."""
|
||||
from lerobot.rollout.configs import (
|
||||
|
||||
@@ -20,14 +20,11 @@ import logging
|
||||
import time
|
||||
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine, _resolve_action_key_order
|
||||
from . import RolloutStrategy
|
||||
from ..inference import InferenceEngine
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,9 +40,10 @@ class BaseStrategy(RolloutStrategy):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
@@ -58,7 +56,6 @@ class BaseStrategy(RolloutStrategy):
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
interpolator=interpolator,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
@@ -69,16 +66,10 @@ class BaseStrategy(RolloutStrategy):
|
||||
engine = self._engine
|
||||
cfg = ctx.cfg
|
||||
robot = ctx.robot_wrapper
|
||||
action_keys = ctx.action_keys
|
||||
interpolator = self._interpolator
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
ordered_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
action_keys,
|
||||
)
|
||||
ordered_keys = ctx.ordered_action_keys
|
||||
|
||||
start_time = time.perf_counter()
|
||||
warmup_flushed = False
|
||||
@@ -98,34 +89,21 @@ class BaseStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
# Wait for torch.compile warmup before running live inference
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_tensor = engine.consume_rtc_action()
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)}
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
|
||||
else:
|
||||
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action_sync(obs_frame)
|
||||
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
|
||||
@@ -29,26 +29,26 @@ Keyboard Controls:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.control_utils import is_headless, predict_action
|
||||
from lerobot.common.control_utils import is_headless
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.processor import RobotProcessorPipeline
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.device_utils import get_safe_torch_device
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
from ..configs import DAggerStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine, _resolve_action_key_order
|
||||
from . import RolloutStrategy
|
||||
from ..inference import InferenceEngine
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,8 +94,19 @@ def _teleop_smooth_move_to(teleop, target_pos: dict, duration_s: float = 2.0, fp
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
def _reset_loop(robot, teleop, events: dict, fps: int) -> None:
|
||||
"""Reset period where the human repositions the environment."""
|
||||
def _reset_loop(
|
||||
robot,
|
||||
teleop,
|
||||
events: dict,
|
||||
fps: int,
|
||||
teleop_action_processor: RobotProcessorPipeline,
|
||||
robot_action_processor: RobotProcessorPipeline,
|
||||
) -> None:
|
||||
"""Reset period where the human repositions the environment.
|
||||
|
||||
All teleop actions flow through the processor pipelines to ensure
|
||||
correct behavior for EE-space robots.
|
||||
"""
|
||||
logger.info("RESET — press any key to enable teleoperation")
|
||||
|
||||
events["in_reset"] = True
|
||||
@@ -117,8 +128,11 @@ def _reset_loop(robot, teleop, events: dict, fps: int) -> None:
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
loop_start = time.perf_counter()
|
||||
obs = robot.get_observation()
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
processed_teleop = teleop_action_processor((action, obs))
|
||||
robot_action_to_send = robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
precise_sleep(1 / fps - (time.perf_counter() - loop_start))
|
||||
|
||||
events["in_reset"] = False
|
||||
@@ -251,6 +265,10 @@ class DAggerStrategy(RolloutStrategy):
|
||||
Supports both synchronous and RTC inference backends.
|
||||
All actions (policy and teleop) flow through the appropriate
|
||||
processor pipelines, supporting EE-space recording.
|
||||
|
||||
Intervention frames are tagged with ``intervention=1`` (int64) in
|
||||
the dataset to allow downstream BC training to distinguish
|
||||
autonomous from human-corrected data.
|
||||
"""
|
||||
|
||||
config: DAggerStrategyConfig
|
||||
@@ -258,11 +276,12 @@ class DAggerStrategy(RolloutStrategy):
|
||||
def __init__(self, config: DAggerStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._listener = None
|
||||
self._events: dict[str, Any] = {}
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
@@ -275,7 +294,6 @@ class DAggerStrategy(RolloutStrategy):
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
interpolator=interpolator,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
@@ -293,7 +311,6 @@ class DAggerStrategy(RolloutStrategy):
|
||||
logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop")
|
||||
|
||||
def run(self, ctx: RolloutContext) -> None:
|
||||
engine = self._engine
|
||||
dataset = ctx.dataset
|
||||
events = self._events
|
||||
teleop = ctx.teleop
|
||||
@@ -317,13 +334,18 @@ class DAggerStrategy(RolloutStrategy):
|
||||
recorded += 1
|
||||
|
||||
if recorded < self.config.num_episodes and not events["stop_recording"]:
|
||||
_reset_loop(ctx.robot_wrapper, teleop, events, int(ctx.cfg.fps))
|
||||
_reset_loop(
|
||||
ctx.robot_wrapper,
|
||||
teleop,
|
||||
events,
|
||||
int(ctx.cfg.fps),
|
||||
ctx.teleop_action_processor,
|
||||
ctx.robot_action_processor,
|
||||
)
|
||||
|
||||
finally:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
log_say("Stop recording", self.config.play_sounds, blocking=True)
|
||||
@@ -360,27 +382,22 @@ class DAggerStrategy(RolloutStrategy):
|
||||
teleop = ctx.teleop
|
||||
dataset = ctx.dataset
|
||||
events = self._events
|
||||
interpolator = self._interpolator
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False
|
||||
record_stride = max(1, cfg.interpolation_multiplier)
|
||||
|
||||
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
ordered_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
ctx.action_keys,
|
||||
)
|
||||
|
||||
dataset_action_keys = list(dataset.features.get(ACTION, {}).get("names", ctx.action_keys))
|
||||
ordered_keys = ctx.ordered_action_keys
|
||||
features = dataset.features
|
||||
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
_teleop_disable_torque(teleop)
|
||||
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
robot_action: dict[str, Any] = {}
|
||||
frame_buffer: list[dict] = []
|
||||
task_str = cfg.dataset.single_task if cfg.dataset else cfg.task
|
||||
|
||||
@@ -444,7 +461,7 @@ class DAggerStrategy(RolloutStrategy):
|
||||
# --- Get observation ---
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.robot_observation_processor(obs)
|
||||
obs_frame = build_dataset_frame(dataset.features, obs_processed, prefix=OBS_STR)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
|
||||
# --- CORRECTION: human teleop control ---
|
||||
if events["correction_active"]:
|
||||
@@ -452,9 +469,14 @@ class DAggerStrategy(RolloutStrategy):
|
||||
processed_teleop = ctx.teleop_action_processor((teleop_action, obs))
|
||||
robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs))
|
||||
robot.send_action(robot_action_to_send)
|
||||
action_frame = build_dataset_frame(dataset.features, processed_teleop, prefix=ACTION)
|
||||
action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([1], dtype=np.int64),
|
||||
}
|
||||
if stream_online:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
@@ -471,73 +493,40 @@ class DAggerStrategy(RolloutStrategy):
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
continue
|
||||
# Wait for torch.compile warmup
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
timestamp = time.perf_counter() - start_t
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_tensor = engine.consume_rtc_action()
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
robot_action = {
|
||||
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
|
||||
if action_dict is not None:
|
||||
last_action = ctx.robot_action_processor((action_dict, obs))
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {
|
||||
**obs_frame,
|
||||
**action_frame,
|
||||
"task": task_str,
|
||||
"intervention": np.array([0], dtype=np.int64),
|
||||
}
|
||||
processed = ctx.robot_action_processor((robot_action, obs))
|
||||
robot.send_action(processed)
|
||||
last_action = processed
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
if stream_online:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
record_tick += 1
|
||||
else:
|
||||
# Sync inference
|
||||
if interpolator.needs_new_action():
|
||||
device = get_safe_torch_device(cfg.device)
|
||||
action_tensor = predict_action(
|
||||
observation=obs_frame,
|
||||
policy=ctx.policy,
|
||||
device=device,
|
||||
preprocessor=ctx.preprocessor,
|
||||
postprocessor=ctx.postprocessor,
|
||||
use_amp=ctx.policy.config.use_amp,
|
||||
task=task_str,
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
robot_action = make_robot_action(action_tensor, dataset.features)
|
||||
action_t = torch.tensor([robot_action[k] for k in dataset_action_keys])
|
||||
interpolator.add(action_t)
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
robot_action = {k: interp[i].item() for i, k in enumerate(dataset_action_keys)}
|
||||
processed = ctx.robot_action_processor((robot_action, obs))
|
||||
robot.send_action(processed)
|
||||
last_action = processed
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
if record_tick % record_stride == 0:
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
if stream_online:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
record_tick += 1
|
||||
if stream_online:
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
frame_buffer.append(frame)
|
||||
record_tick += 1
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
|
||||
@@ -16,21 +16,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Event as ThreadingEvent
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..configs import HighlightStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine, _resolve_action_key_order
|
||||
from ..inference import InferenceEngine
|
||||
from ..ring_buffer import RolloutRingBuffer
|
||||
from . import RolloutStrategy
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,13 +55,14 @@ class HighlightStrategy(RolloutStrategy):
|
||||
def __init__(self, config: HighlightStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._ring: RolloutRingBuffer | None = None
|
||||
self._listener = None
|
||||
self._save_requested = False
|
||||
self._recording_live = False
|
||||
self._save_requested = ThreadingEvent()
|
||||
self._recording_live = ThreadingEvent()
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
@@ -73,7 +75,6 @@ class HighlightStrategy(RolloutStrategy):
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
interpolator=interpolator,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
@@ -97,17 +98,12 @@ class HighlightStrategy(RolloutStrategy):
|
||||
cfg = ctx.cfg
|
||||
robot = ctx.robot_wrapper
|
||||
dataset = ctx.dataset
|
||||
action_keys = ctx.action_keys
|
||||
ring = self._ring
|
||||
interpolator = self._interpolator
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
ordered_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
action_keys,
|
||||
)
|
||||
ordered_keys = ctx.ordered_action_keys
|
||||
features = dataset.features
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
@@ -126,70 +122,58 @@ class HighlightStrategy(RolloutStrategy):
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.robot_observation_processor(obs)
|
||||
action_dict = None
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_tensor = engine.consume_rtc_action()
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
action_dict = {
|
||||
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
|
||||
}
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
else:
|
||||
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action_sync(obs_frame)
|
||||
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
|
||||
# Build frame for ring buffer / live recording
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
|
||||
# Handle save key toggle
|
||||
if self._save_requested:
|
||||
self._save_requested = False
|
||||
if not self._recording_live:
|
||||
if self._save_requested.is_set():
|
||||
self._save_requested.clear()
|
||||
if not self._recording_live.is_set():
|
||||
logger.info(
|
||||
"Flushing ring buffer (%d frames) + starting live recording", len(ring)
|
||||
)
|
||||
for buffered_frame in ring.drain():
|
||||
dataset.add_frame(buffered_frame)
|
||||
self._recording_live = True
|
||||
self._recording_live.set()
|
||||
else:
|
||||
# Save current frame as the last frame of the episode
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
logger.info("Episode saved")
|
||||
self._recording_live = False
|
||||
self._recording_live.clear()
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
if self._recording_live:
|
||||
if self._recording_live.is_set():
|
||||
dataset.add_frame(frame)
|
||||
else:
|
||||
# Current frame goes into the ring buffer for next potential save.
|
||||
ring.append(frame)
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
@@ -197,11 +181,9 @@ class HighlightStrategy(RolloutStrategy):
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
if self._recording_live:
|
||||
try:
|
||||
if self._recording_live.is_set():
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
if self._engine is not None:
|
||||
@@ -237,13 +219,11 @@ class HighlightStrategy(RolloutStrategy):
|
||||
save_key = self.config.save_key
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
if hasattr(key, "char") and key.char == save_key:
|
||||
self._save_requested = True
|
||||
self._save_requested.set()
|
||||
elif key == keyboard.Key.esc:
|
||||
self._save_requested = False
|
||||
except Exception:
|
||||
pass
|
||||
self._save_requested.clear()
|
||||
|
||||
self._listener = keyboard.Listener(on_press=on_press)
|
||||
self._listener.start()
|
||||
|
||||
@@ -16,21 +16,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from lerobot.datasets import VideoEncodingManager
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
from lerobot.policies.utils import make_robot_action
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
from ..configs import SentryStrategyConfig
|
||||
from ..context import RolloutContext
|
||||
from ..inference import InferenceEngine, _resolve_action_key_order
|
||||
from . import RolloutStrategy
|
||||
from ..inference import InferenceEngine
|
||||
from . import RolloutStrategy, infer_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -55,10 +55,12 @@ class SentryStrategy(RolloutStrategy):
|
||||
def __init__(self, config: SentryStrategyConfig):
|
||||
super().__init__(config)
|
||||
self._engine: InferenceEngine | None = None
|
||||
self._interpolator: ActionInterpolator | None = None
|
||||
self._push_thread: Thread | None = None
|
||||
self._needs_push: bool = False
|
||||
|
||||
def setup(self, ctx: RolloutContext) -> None:
|
||||
interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier)
|
||||
|
||||
self._engine = InferenceEngine(
|
||||
policy=ctx.policy,
|
||||
@@ -71,7 +73,6 @@ class SentryStrategy(RolloutStrategy):
|
||||
task=ctx.cfg.task,
|
||||
fps=ctx.cfg.fps,
|
||||
device=ctx.cfg.device,
|
||||
interpolator=interpolator,
|
||||
use_torch_compile=ctx.cfg.use_torch_compile,
|
||||
compile_warmup_inferences=ctx.cfg.compile_warmup_inferences,
|
||||
)
|
||||
@@ -87,16 +88,11 @@ class SentryStrategy(RolloutStrategy):
|
||||
cfg = ctx.cfg
|
||||
robot = ctx.robot_wrapper
|
||||
dataset = ctx.dataset
|
||||
action_keys = ctx.action_keys
|
||||
interpolator = self._interpolator
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(cfg.fps)
|
||||
|
||||
policy_action_names = getattr(cfg.policy, "action_feature_names", None)
|
||||
ordered_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
action_keys,
|
||||
)
|
||||
ordered_keys = ctx.ordered_action_keys
|
||||
features = dataset.features
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
@@ -117,45 +113,31 @@ class SentryStrategy(RolloutStrategy):
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = ctx.robot_observation_processor(obs)
|
||||
action_dict = None
|
||||
|
||||
if engine.is_rtc:
|
||||
engine.update_observation(obs_processed)
|
||||
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
if cfg.use_torch_compile and not engine.compile_warmup_done.is_set():
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
continue
|
||||
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if cfg.use_torch_compile and not warmup_flushed:
|
||||
engine.reset()
|
||||
interpolator.reset()
|
||||
warmup_flushed = True
|
||||
if engine.is_rtc:
|
||||
engine.resume()
|
||||
|
||||
if interpolator.needs_new_action():
|
||||
action_tensor = engine.consume_rtc_action()
|
||||
if action_tensor is not None:
|
||||
interpolator.add(action_tensor.cpu())
|
||||
|
||||
interp = interpolator.get()
|
||||
if interp is not None:
|
||||
action_dict = {
|
||||
k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)
|
||||
}
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
else:
|
||||
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
|
||||
action_tensor = engine.get_action_sync(obs_frame)
|
||||
action_dict = make_robot_action(action_tensor, ctx.dataset_features)
|
||||
processed = ctx.robot_action_processor((action_dict, obs))
|
||||
robot.send_action(processed)
|
||||
action_dict = infer_action(
|
||||
engine, obs_processed, obs, ctx, interpolator, ordered_keys, features
|
||||
)
|
||||
|
||||
# Record frame
|
||||
if action_dict is not None:
|
||||
obs_frame = build_dataset_frame(ctx.dataset_features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(ctx.dataset_features, action_dict, prefix=ACTION)
|
||||
obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(features, action_dict, prefix=ACTION)
|
||||
frame = {**obs_frame, **action_frame, "task": task_str}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
@@ -164,6 +146,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
if elapsed >= self.config.episode_duration_s:
|
||||
dataset.save_episode()
|
||||
episodes_since_push += 1
|
||||
self._needs_push = True
|
||||
logger.info("Episode saved (total: %d)", dataset.num_episodes)
|
||||
|
||||
if episodes_since_push >= self.config.upload_every_n_episodes:
|
||||
@@ -181,26 +164,27 @@ class SentryStrategy(RolloutStrategy):
|
||||
precise_sleep(sleep_t)
|
||||
|
||||
finally:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
dataset.save_episode()
|
||||
except Exception:
|
||||
pass
|
||||
self._needs_push = True
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
if self._engine is not None:
|
||||
self._engine.stop()
|
||||
|
||||
# Wait for any in-flight background push
|
||||
if self._push_thread is not None and self._push_thread.is_alive():
|
||||
self._push_thread.join(timeout=60)
|
||||
|
||||
if ctx.dataset is not None:
|
||||
ctx.dataset.finalize()
|
||||
if ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
|
||||
# Only push if there are unsaved changes since last background push
|
||||
if self._needs_push and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub:
|
||||
ctx.dataset.push_to_hub(
|
||||
tags=ctx.cfg.dataset.tags,
|
||||
private=ctx.cfg.dataset.private,
|
||||
)
|
||||
|
||||
if self._push_thread is not None and self._push_thread.is_alive():
|
||||
self._push_thread.join(timeout=60)
|
||||
|
||||
if ctx.robot.is_connected:
|
||||
ctx.robot.disconnect()
|
||||
if ctx.teleop is not None and ctx.teleop.is_connected:
|
||||
@@ -219,6 +203,7 @@ class SentryStrategy(RolloutStrategy):
|
||||
tags=cfg.dataset.tags if cfg.dataset else None,
|
||||
private=cfg.dataset.private if cfg.dataset else False,
|
||||
)
|
||||
self._needs_push = False
|
||||
logger.info("Background push to hub complete")
|
||||
except Exception as e:
|
||||
logger.error("Background push failed: %s", e)
|
||||
|
||||
Reference in New Issue
Block a user