removing temporary debug code

This commit is contained in:
Maxime Ellerbach
2026-06-18 09:44:47 +00:00
parent 5752558467
commit 2ec82c68b4
@@ -15,7 +15,6 @@
from __future__ import annotations
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any
@@ -35,16 +34,6 @@ from .wan_components import (
)
from .wan_video_dit import WanVideoDiT
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
# eval episode's action chunks through `infer_joint` so the predicted video latents are
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
_DEBUG_PRED_RATE_DIV = 1
class FastWAMPolicy(PreTrainedPolicy):
"""LeRobot policy wrapper for FastWAM.
@@ -84,17 +73,6 @@ class FastWAMPolicy(PreTrainedPolicy):
if "video" in layer.blocks:
layer.blocks["video"].requires_grad_(False)
self.reset()
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
# counts only eval-rollout resets (one per episode), not this __init__ one.
self._debug_constructed = True
self._debug_episode_index = -1
self._debug_seen_tasks: set[str] = set()
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video: list | None = None
self._debug_pairs: list = []
@classmethod
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
@@ -179,19 +157,6 @@ class FastWAMPolicy(PreTrainedPolicy):
def reset(self) -> None:
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
# true-vs-pred video if it was a captured one (pairs accumulate only while
# capturing), then reset per-episode capture state.
if getattr(self, "_debug_constructed", False):
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
self._save_debug_video()
self._debug_episode_index += 1
self._debug_capturing = False
self._debug_episode_started = False
self._debug_episode_task = ""
self._debug_step_in_chunk = 0
self._debug_last_video = None
self._debug_pairs = []
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
@@ -259,21 +224,7 @@ class FastWAMPolicy(PreTrainedPolicy):
self.eval()
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
batch_size = _infer_kwargs_batch_size(infer_kwargs)
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
# run the joint video+action path so the predicted video is VAE-decoded; stash it
# so select_action can pair each predicted frame with the real obs that follows.
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
out = self.model.infer_joint(
**infer_kwargs,
num_video_frames=self.config.model_video_frames,
test_action_with_infer_action=False,
)
# The decoded rollout has model_video_frames frames spanning the full
# action_horizon (action_video_freq_ratio actions per frame); the per-step
# pairing indexes into it, so keep all frames.
self._debug_last_video = out["video"]
action = _action_from_model_output(out)
elif batch_size == 1:
if batch_size == 1:
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
else:
action = torch.cat(
@@ -292,101 +243,11 @@ class FastWAMPolicy(PreTrainedPolicy):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
self.eval()
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
# get the first episode of every task).
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
self._debug_episode_started = True
task = self._debug_task_name(batch)
if task not in self._debug_seen_tasks:
self._debug_seen_tasks.add(task)
self._debug_capturing = True
self._debug_episode_task = task
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
if capturing:
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
if capturing:
self._debug_capture_pair(batch)
self._debug_step_in_chunk += 1
return self._action_queue.popleft()
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
@staticmethod
def _debug_task_name(batch: dict[str, Any]) -> str:
task = batch.get("task")
if isinstance(task, (list, tuple)):
task = task[0] if task else None
return str(task) if task else "no_task"
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
video = getattr(self, "_debug_last_video", None)
if not video:
return
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
idx = min(
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
len(video) - 1,
)
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
self._debug_pairs.append(pair)
@staticmethod
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
from PIL import ImageDraw
draw = ImageDraw.Draw(pair)
draw.text((3, 3), "true", fill=(255, 255, 0))
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
@staticmethod
def _debug_tensor_to_pil(image: Tensor):
from PIL import Image
# `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE
# encode boundary), so map [0, 1] -> [0, 255] for display.
arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8)
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
@staticmethod
def _debug_hstack(left, right):
from PIL import Image
if right.height != left.height:
right = right.resize((round(right.width * left.height / right.height), left.height))
canvas = Image.new("RGB", (left.width + right.width, left.height))
canvas.paste(left, (0, 0))
canvas.paste(right, (left.width, 0))
return canvas
def _save_debug_video(self) -> None:
import re
import numpy as np
from lerobot.utils.io_utils import write_video
pairs = getattr(self, "_debug_pairs", None)
if not pairs:
return
out_dir = Path("outputs/fastwam_debug")
out_dir.mkdir(parents=True, exist_ok=True)
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
write_video(path, frames, fps=30)
logging.info(
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
)
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
"""Build the FastWAM core for training / inference.