mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
feat(policies): implement RTC to EVO1
This commit is contained in:
+23
-33
@@ -32,9 +32,9 @@ The broader EVO1 project may include additional training scripts and dataset too
|
||||
pip install -e ".[evo1,libero]"
|
||||
```
|
||||
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available, but reproducing the official LIBERO checkpoint conversion result below requires the same FlashAttention path used by the original EVO1 checkpoint.
|
||||
3. Install a `flash-attn` wheel only if it is compatible with your Python, PyTorch, CUDA, and GPU stack. EVO1 falls back to standard attention when `flash_attn` is not available.
|
||||
|
||||
EVO1 uses the native Hugging Face `transformers` InternVL implementation (no `trust_remote_code`), so `policy.vlm_model_name` must point to a natively converted checkpoint such as `OpenGVLab/InternVL3-1B-hf` (note the `-hf` suffix; the original `OpenGVLab/InternVL3-1B` repo requires remote code and cannot be loaded). The first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
EVO1 uses the native Hugging Face `transformers` InternVL implementation, so `policy.vlm_model_name` must point to a natively converted checkpoint such as `OpenGVLab/InternVL3-1B-hf` (note the `-hf` suffix). The first run may download the configured VLM checkpoint unless `policy.vlm_model_name` points to a local model directory.
|
||||
|
||||
## Data Requirements
|
||||
|
||||
@@ -67,12 +67,6 @@ Once a LeRobot-format EVO1 checkpoint is available, load it with:
|
||||
policy.path=your-org/your-evo1-checkpoint
|
||||
```
|
||||
|
||||
The converted LIBERO checkpoint used for this PR is available at:
|
||||
|
||||
```python
|
||||
policy.path=javadcc/evo1-libero-lerobot
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Stage 1
|
||||
@@ -143,39 +137,35 @@ every finetuning flag.
|
||||
| `policy.binarize_gripper` | `false` | Binarizes the postprocessed gripper channel for LIBERO-style eval |
|
||||
| `policy.task_field` | `task` | Batch field used as the language prompt |
|
||||
|
||||
## Inference
|
||||
|
||||
Try it out with a trained EVO1 checkpoint:
|
||||
|
||||
```bash
|
||||
lerobot-rollout \
|
||||
--policy.path=your-org/your-evo1-checkpoint \
|
||||
--inference.type=rtc \ # optional
|
||||
...
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
### LIBERO Object Checkpoint Conversion
|
||||
|
||||
The checkpoint [javadcc/evo1-libero-lerobot](https://huggingface.co/javadcc/evo1-libero-lerobot)
|
||||
is the LeRobot-format conversion of the official EVO1 LIBERO checkpoint. The conversion was checked against
|
||||
the official EVO1 checkpoint with the same LIBERO Object initial states and action postprocessing.
|
||||
### LIBERO Evaluation
|
||||
|
||||
> [!NOTE]
|
||||
> This checkpoint is currently hosted in a community namespace and the upstream-to-LeRobot weight
|
||||
> conversion script is not part of this integration; a `lerobot`-hosted copy with a pinned revision
|
||||
> and the conversion tooling are planned follow-ups.
|
||||
> Benchmark results for a `lerobot`-hosted LIBERO checkpoint trained with this implementation
|
||||
> will be added once training completes.
|
||||
|
||||
| Checkpoint | Suite | Episodes | Success Rate |
|
||||
| ---------------------------- | --------------- | ---------------- | ------------ |
|
||||
| Official EVO1 checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
| LeRobot converted checkpoint | `libero_object` | 10, one per task | 100% |
|
||||
|
||||
For a fixed `libero_object` rollout, the official checkpoint and LeRobot checkpoint produced identical
|
||||
pixel embeddings, VLM fused tokens, normalized actions, and denormalized actions for the checked action step
|
||||
(`max_abs_diff=0.0`).
|
||||
|
||||
The published checkpoint expects the raw LIBERO camera feature names
|
||||
`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`. The official EVO1 LIBERO
|
||||
rollout protocol also replans every 14 actions and binarizes the gripper command before stepping the simulator.
|
||||
The EVO1 policy postprocessor can crop the padded 24D action back to the 7D LIBERO action space and apply that
|
||||
gripper binarization. To run the converted checkpoint with LeRobot LIBERO evaluation for the same
|
||||
one-episode-per-task setting, keep the raw camera names instead of the default `image`/`image2` mapping, enable
|
||||
FlashAttention, and set the LIBERO action postprocessing flags:
|
||||
The official EVO1 LIBERO rollout protocol uses the raw LIBERO camera feature names
|
||||
(`observation.images.agentview_image` and `observation.images.robot0_eye_in_hand_image`), replans every
|
||||
14 actions, and binarizes the gripper command before stepping the simulator. The EVO1 policy postprocessor
|
||||
can crop the padded 24D action back to the 7D LIBERO action space and apply that gripper binarization. To
|
||||
evaluate a LIBERO checkpoint under the same one-episode-per-task setting, keep the raw camera names instead
|
||||
of the default `image`/`image2` mapping and set the LIBERO action postprocessing flags:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=javadcc/evo1-libero-lerobot \
|
||||
--policy.path=your-org/your-evo1-libero-checkpoint \
|
||||
--policy.vlm_model_name=OpenGVLab/InternVL3-1B-hf \
|
||||
--policy.device=cuda \
|
||||
--policy.use_flash_attn=true \
|
||||
|
||||
@@ -23,6 +23,8 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineAnnealingWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -93,6 +95,10 @@ class Evo1Config(PreTrainedConfig):
|
||||
embodiment_id_field: str | None = None
|
||||
default_embodiment_id: int = 0
|
||||
|
||||
# Real-Time Chunking guidance for asynchronous inference (lerobot-rollout --inference.type=rtc
|
||||
# sets this and calls init_rtc_processor()); None disables RTC.
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
optimizer_lr: float = 1e-5
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
|
||||
@@ -28,6 +28,8 @@ class Evo1Model(nn.Module):
|
||||
self.config = config
|
||||
self._device = config.device
|
||||
self.return_cls_only = config.return_cls_only
|
||||
# Set by Evo1Policy.init_rtc_processor() when config.rtc_config is provided.
|
||||
self.rtc_processor = None
|
||||
|
||||
# Gradient checkpointing only pays off when the VLM is actually being trained; keep it off
|
||||
# whenever every VLM branch is frozen so the frozen forward stays cheap.
|
||||
@@ -130,6 +132,9 @@ class Evo1Model(nn.Module):
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
prev_chunk_left_over: torch.Tensor | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
):
|
||||
if actions_gt is None:
|
||||
return self.action_head.get_action(
|
||||
@@ -138,6 +143,10 @@ class Evo1Model(nn.Module):
|
||||
action_mask=action_mask,
|
||||
embodiment_id=embodiment_ids,
|
||||
context_mask=context_mask,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=execution_horizon,
|
||||
rtc_processor=self.rtc_processor,
|
||||
)
|
||||
return self.action_head(
|
||||
fused_tokens,
|
||||
@@ -156,8 +165,21 @@ class Evo1Model(nn.Module):
|
||||
action_mask: torch.Tensor | None = None,
|
||||
embodiment_ids: torch.Tensor | None = None,
|
||||
context_mask: torch.Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
prev_chunk_left_over: torch.Tensor | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
):
|
||||
return self.predict_action(fused_tokens, state, actions_gt, action_mask, embodiment_ids, context_mask)
|
||||
return self.predict_action(
|
||||
fused_tokens,
|
||||
state,
|
||||
actions_gt,
|
||||
action_mask,
|
||||
embodiment_ids,
|
||||
context_mask,
|
||||
inference_delay,
|
||||
prev_chunk_left_over,
|
||||
execution_horizon,
|
||||
)
|
||||
|
||||
def _set_module_trainable(self, module: nn.Module, trainable: bool):
|
||||
for param in module.parameters():
|
||||
|
||||
@@ -397,6 +397,10 @@ class FlowmatchingActionHead(nn.Module):
|
||||
embodiment_id: torch.LongTensor = None,
|
||||
action_mask: torch.Tensor = None,
|
||||
context_mask: torch.Tensor = None,
|
||||
inference_delay: int | None = None,
|
||||
prev_chunk_left_over: torch.Tensor | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
rtc_processor=None,
|
||||
):
|
||||
batch_size = fused_tokens.size(0)
|
||||
device = fused_tokens.device
|
||||
@@ -408,11 +412,7 @@ class FlowmatchingActionHead(nn.Module):
|
||||
per_action_dim = self.per_action_dim
|
||||
|
||||
action = torch.rand(batch_size, action_dim_total, device=device, dtype=context_tokens.dtype) * 2 - 1
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
action_seq = action.view(batch_size, self.horizon, per_action_dim)
|
||||
action_mask = self._expand_action_mask(
|
||||
action_mask,
|
||||
batch_size=batch_size,
|
||||
@@ -430,36 +430,46 @@ class FlowmatchingActionHead(nn.Module):
|
||||
raise ValueError(f"num_inference_timesteps must be positive, got {num_steps}")
|
||||
dt = 1.0 / num_steps
|
||||
|
||||
use_rtc = rtc_processor is not None and (
|
||||
inference_delay is not None or prev_chunk_left_over is not None
|
||||
)
|
||||
|
||||
def predict_velocity(seq: torch.Tensor, step_time_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Predict the masked flow velocity (x1 - x0 convention) for one integration step."""
|
||||
seq = seq * action_mask
|
||||
action_tokens = self._project_actions(seq, embodiment_id).to(dtype=target_dtype)
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, step_time_emb, key_padding_mask)
|
||||
x = self.norm_out(x)
|
||||
x_pooled = self.seq_pool_proj(x.reshape(batch_size, -1)) if self.horizon > 1 else x.squeeze(1)
|
||||
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||
return pred.view(batch_size, self.horizon, per_action_dim) * action_mask
|
||||
|
||||
for i in range(num_steps):
|
||||
t = i / num_steps
|
||||
time_index = min(int(t * 999), 999)
|
||||
time_emb = (
|
||||
self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=context_tokens.dtype)
|
||||
)
|
||||
time_emb = self.time_pos_enc(1000)[:, time_index, :].to(device).squeeze(0).to(dtype=target_dtype)
|
||||
time_emb = time_emb.unsqueeze(0).repeat(batch_size, 1)
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
action_tokens = self._project_actions(action_seq, embodiment_id).to(dtype=target_dtype)
|
||||
time_emb = time_emb.to(dtype=target_dtype)
|
||||
|
||||
x = action_tokens
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context_tokens, time_emb, key_padding_mask)
|
||||
x = self.norm_out(x)
|
||||
|
||||
if self.horizon > 1:
|
||||
x_flat = x.reshape(batch_size, -1)
|
||||
x_pooled = self.seq_pool_proj(x_flat)
|
||||
if use_rtc:
|
||||
# RTCProcessor assumes the pi0 flow convention: its `time` runs 1 -> 0 and the
|
||||
# clean-action estimate is x1 = x_t - time * v. EVO1 integrates t: 0 -> 1 with
|
||||
# velocity v = x1 - x0 (so x1 = x_t + (1 - t) * v); passing time = 1 - t and
|
||||
# flipping the velocity sign in both directions maps one convention onto the other.
|
||||
guided = rtc_processor.denoise_step(
|
||||
x_t=action_seq,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=1.0 - t,
|
||||
original_denoise_step_partial=lambda seq, emb=time_emb: -predict_velocity(seq, emb),
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
velocity = -guided
|
||||
else:
|
||||
x_pooled = x.squeeze(1)
|
||||
velocity = predict_velocity(action_seq, time_emb)
|
||||
|
||||
pred = self.mlp_head(x_pooled, embodiment_id)
|
||||
action = action + dt * pred
|
||||
action_seq = (
|
||||
action.view(batch_size, self.horizon, per_action_dim)
|
||||
if self.horizon > 1
|
||||
else action.view(batch_size, 1, per_action_dim)
|
||||
)
|
||||
action_seq = action_seq + dt * velocity
|
||||
|
||||
action_seq = action_seq * action_mask
|
||||
return action_seq.reshape(batch_size, -1)
|
||||
|
||||
@@ -18,6 +18,7 @@ import builtins
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Unpack
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -26,10 +27,17 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE
|
||||
|
||||
from ..rtc.modeling_rtc import RTCProcessor
|
||||
from .configuration_evo1 import Evo1Config
|
||||
from .evo1_model import Evo1Model
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
class Evo1Policy(PreTrainedPolicy):
|
||||
config_class = Evo1Config
|
||||
name = "evo1"
|
||||
@@ -47,8 +55,25 @@ class Evo1Policy(PreTrainedPolicy):
|
||||
self.model = Evo1Model(config, vlm_hub_kwargs=vlm_hub_kwargs)
|
||||
self.model.set_finetune_flags()
|
||||
self._keep_frozen_embedder_eval()
|
||||
self.init_rtc_processor()
|
||||
self.reset()
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Create the RTC processor when config.rtc_config is set.
|
||||
|
||||
The RTC rollout backend assigns config.rtc_config after loading the policy and re-invokes
|
||||
this method.
|
||||
"""
|
||||
self.rtc_processor = None
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
model = getattr(self, "model", None)
|
||||
if model is not None:
|
||||
model.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
@@ -457,11 +482,15 @@ class Evo1Policy(PreTrainedPolicy):
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
if kwargs.get("inference_delay") is not None or kwargs.get("prev_chunk_left_over") is not None:
|
||||
raise NotImplementedError(
|
||||
"EVO1 does not implement real-time-chunking (RTC) inference; "
|
||||
"use the synchronous inference backend."
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
if (inference_delay is not None or prev_chunk_left_over is not None) and not self._rtc_enabled():
|
||||
raise RuntimeError(
|
||||
"Received RTC arguments but RTC is not configured for this EVO1 policy: set "
|
||||
"config.rtc_config and call init_rtc_processor() (lerobot-rollout does this for "
|
||||
"--inference.type=rtc)."
|
||||
)
|
||||
self.eval()
|
||||
|
||||
@@ -470,6 +499,8 @@ class Evo1Policy(PreTrainedPolicy):
|
||||
states, _state_mask = self._prepare_state(batch)
|
||||
embodiment_ids = self._get_embodiment_ids(batch, states.shape[0])
|
||||
action_mask = self._prepare_inference_action_mask(states.shape[0])
|
||||
if prev_chunk_left_over is not None:
|
||||
prev_chunk_left_over = prev_chunk_left_over.to(device=self._device)
|
||||
|
||||
with self._maybe_autocast():
|
||||
fused_tokens, context_mask = self._compute_fused_tokens(prompts, image_batches, image_masks)
|
||||
@@ -479,12 +510,18 @@ class Evo1Policy(PreTrainedPolicy):
|
||||
action_mask=action_mask,
|
||||
embodiment_ids=embodiment_ids,
|
||||
context_mask=context_mask,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
actions = actions.view(states.shape[0], self.config.chunk_size, self.config.max_action_dim)
|
||||
return actions.to(dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
self.eval()
|
||||
if len(self._action_queue) == 0:
|
||||
action_chunk = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
|
||||
|
||||
@@ -39,6 +39,8 @@ from lerobot.policies.evo1.processor_evo1 import (
|
||||
reconcile_evo1_processors,
|
||||
)
|
||||
from lerobot.policies.factory import get_policy_class, make_policy_config
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor import (
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
@@ -98,6 +100,7 @@ class DummyEvo1Model(nn.Module):
|
||||
action_mask=None,
|
||||
embodiment_ids=None,
|
||||
context_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
batch_size = fused_tokens.shape[0]
|
||||
if actions_gt is None:
|
||||
@@ -122,6 +125,7 @@ class ChunkCountingDummyModel(DummyEvo1Model):
|
||||
action_mask=None,
|
||||
embodiment_ids=None,
|
||||
context_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
if actions_gt is not None:
|
||||
return super().forward(fused_tokens, state, actions_gt, action_mask, embodiment_ids, context_mask)
|
||||
@@ -330,19 +334,23 @@ def test_evo1_model_uses_image_resolution_and_trainable_checkpointing(monkeypatc
|
||||
assert captured["enable_gradient_checkpointing"] is True
|
||||
|
||||
|
||||
class FakeInternVLModel(nn.Module):
|
||||
"""Minimal stand-in with the native HF InternVL submodule layout."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.language_model = nn.Linear(2, 2)
|
||||
self.vision_tower = nn.Linear(2, 2)
|
||||
self.multi_modal_projector = nn.Linear(2, 2)
|
||||
|
||||
|
||||
class FakeEmbedder(nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.model = FakeInternVLModel()
|
||||
|
||||
|
||||
def test_set_finetune_flags_targets_native_hf_internvl_submodules(monkeypatch):
|
||||
class FakeInternVLModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.language_model = nn.Linear(2, 2)
|
||||
self.vision_tower = nn.Linear(2, 2)
|
||||
self.multi_modal_projector = nn.Linear(2, 2)
|
||||
|
||||
class FakeEmbedder(nn.Module):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.model = FakeInternVLModel()
|
||||
|
||||
monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
|
||||
|
||||
stage2_model = evo1_model.Evo1Model(make_config(training_stage="stage2"))
|
||||
@@ -550,13 +558,82 @@ def test_evo1_select_action_queue_orders_steps_and_repredicts(monkeypatch):
|
||||
assert policy.model.chunks_predicted == 2
|
||||
|
||||
|
||||
def test_evo1_predict_action_chunk_rejects_rtc_kwargs(monkeypatch):
|
||||
def test_evo1_predict_action_chunk_rejects_rtc_kwargs_without_rtc_config(monkeypatch):
|
||||
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
|
||||
policy = modeling_evo1.Evo1Policy(make_config())
|
||||
with pytest.raises(NotImplementedError, match="RTC"):
|
||||
with pytest.raises(RuntimeError, match="RTC"):
|
||||
policy.predict_action_chunk(make_batch(include_action=False), inference_delay=2)
|
||||
|
||||
|
||||
def test_evo1_rtc_processor_wiring(monkeypatch):
|
||||
monkeypatch.setattr(evo1_model, "InternVL3Embedder", FakeEmbedder)
|
||||
policy = modeling_evo1.Evo1Policy(make_config())
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
|
||||
# The RTC rollout backend assigns rtc_config after loading and re-inits the processor.
|
||||
policy.config.rtc_config = RTCConfig(execution_horizon=CHUNK_SIZE)
|
||||
policy.init_rtc_processor()
|
||||
assert isinstance(policy.rtc_processor, RTCProcessor)
|
||||
assert policy.model.rtc_processor is policy.rtc_processor
|
||||
|
||||
# RTC drives predict_action_chunk directly; the select_action queue path is unsupported.
|
||||
with pytest.raises(AssertionError, match="select_action"):
|
||||
policy.select_action(make_batch(include_action=False))
|
||||
|
||||
|
||||
def test_flowmatching_rtc_guidance_pulls_prefix_toward_previous_chunk():
|
||||
head = make_flowmatching_head(num_inference_timesteps=16)
|
||||
processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
|
||||
fused = torch.randn(2, 4, EMBED_DIM)
|
||||
state = torch.randn(2, STATE_DIM)
|
||||
action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
|
||||
prev_chunk = torch.tensor([0.7, -0.4, 0.2]).expand(2, CHUNK_SIZE, ACTION_DIM).contiguous()
|
||||
|
||||
torch.manual_seed(0)
|
||||
unguided = head.get_action(fused, state=state, action_mask=action_mask)
|
||||
unguided = unguided.view(2, CHUNK_SIZE, ACTION_DIM)
|
||||
torch.manual_seed(0)
|
||||
guided = head.get_action(
|
||||
fused,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
inference_delay=1,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
rtc_processor=processor,
|
||||
)
|
||||
guided = guided.view(2, CHUNK_SIZE, ACTION_DIM)
|
||||
|
||||
# The frozen prefix (first inference_delay steps) must land far closer to the previous chunk
|
||||
# than the unguided sample from the same noise does.
|
||||
guided_dist = (guided[:, 0] - prev_chunk[:, 0]).abs().mean()
|
||||
unguided_dist = (unguided[:, 0] - prev_chunk[:, 0]).abs().mean()
|
||||
assert guided_dist < 0.5 * unguided_dist
|
||||
assert torch.isfinite(guided).all()
|
||||
|
||||
|
||||
def test_flowmatching_rtc_first_chunk_without_leftover_matches_unguided():
|
||||
head = make_flowmatching_head(num_inference_timesteps=4)
|
||||
processor = RTCProcessor(RTCConfig(execution_horizon=CHUNK_SIZE))
|
||||
fused = torch.randn(2, 4, EMBED_DIM)
|
||||
state = torch.randn(2, STATE_DIM)
|
||||
action_mask = torch.ones(2, ACTION_DIM, dtype=torch.bool)
|
||||
|
||||
torch.manual_seed(0)
|
||||
unguided = head.get_action(fused, state=state, action_mask=action_mask)
|
||||
torch.manual_seed(0)
|
||||
first_chunk = head.get_action(
|
||||
fused,
|
||||
state=state,
|
||||
action_mask=action_mask,
|
||||
inference_delay=2,
|
||||
prev_chunk_left_over=None,
|
||||
rtc_processor=processor,
|
||||
)
|
||||
|
||||
assert torch.allclose(unguided, first_chunk)
|
||||
|
||||
|
||||
def test_evo1_missing_configured_camera_needs_empty_cameras_budget(monkeypatch):
|
||||
monkeypatch.setattr(modeling_evo1, "Evo1Model", DummyEvo1Model)
|
||||
batch = make_batch(include_action=False) # only provides the front camera
|
||||
|
||||
Reference in New Issue
Block a user