From 2ec82c68b4de94d14c0bd193cc6c1ab0091b5924 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Thu, 18 Jun 2026 09:44:47 +0000 Subject: [PATCH] removing temporary debug code --- .../policies/fastwam/modeling_fastwam.py | 141 +----------------- 1 file changed, 1 insertion(+), 140 deletions(-) diff --git a/src/lerobot/policies/fastwam/modeling_fastwam.py b/src/lerobot/policies/fastwam/modeling_fastwam.py index 2f4e86229..0c99613f3 100644 --- a/src/lerobot/policies/fastwam/modeling_fastwam.py +++ b/src/lerobot/policies/fastwam/modeling_fastwam.py @@ -15,7 +15,6 @@ from __future__ import annotations import logging -import os from collections import deque from pathlib import Path from typing import Any @@ -35,16 +34,6 @@ from .wan_components import ( ) from .wan_video_dit import WanVideoDiT -# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first -# eval episode's action chunks through `infer_joint` so the predicted video latents are -# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path). -_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1" -# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1 -# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/ -# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the -# (now-fixed) bug where the model ran on the un-subsampled num_video_frames. -_DEBUG_PRED_RATE_DIV = 1 - class FastWAMPolicy(PreTrainedPolicy): """LeRobot policy wrapper for FastWAM. @@ -84,17 +73,6 @@ class FastWAMPolicy(PreTrainedPolicy): if "video" in layer.blocks: layer.blocks["video"].requires_grad_(False) self.reset() - # TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()` - # counts only eval-rollout resets (one per episode), not this __init__ one. - self._debug_constructed = True - self._debug_episode_index = -1 - self._debug_seen_tasks: set[str] = set() - self._debug_capturing = False - self._debug_episode_started = False - self._debug_episode_task = "" - self._debug_step_in_chunk = 0 - self._debug_last_video: list | None = None - self._debug_pairs: list = [] @classmethod def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool): @@ -179,19 +157,6 @@ class FastWAMPolicy(PreTrainedPolicy): def reset(self) -> None: self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps) - # TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's - # true-vs-pred video if it was a captured one (pairs accumulate only while - # capturing), then reset per-episode capture state. - if getattr(self, "_debug_constructed", False): - if _FASTWAM_DECODE_DEBUG and self._debug_pairs: - self._save_debug_video() - self._debug_episode_index += 1 - self._debug_capturing = False - self._debug_episode_started = False - self._debug_episode_task = "" - self._debug_step_in_chunk = 0 - self._debug_last_video = None - self._debug_pairs = [] def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Adapt a standard LeRobot batch to the FastWAM-native sample that @@ -259,21 +224,7 @@ class FastWAMPolicy(PreTrainedPolicy): self.eval() infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config) batch_size = _infer_kwargs_batch_size(infer_kwargs) - # TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task), - # run the joint video+action path so the predicted video is VAE-decoded; stash it - # so select_action can pair each predicted frame with the real obs that follows. - if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1: - out = self.model.infer_joint( - **infer_kwargs, - num_video_frames=self.config.model_video_frames, - test_action_with_infer_action=False, - ) - # The decoded rollout has model_video_frames frames spanning the full - # action_horizon (action_video_freq_ratio actions per frame); the per-step - # pairing indexes into it, so keep all frames. - self._debug_last_video = out["video"] - action = _action_from_model_output(out) - elif batch_size == 1: + if batch_size == 1: action = _action_from_model_output(self.model.infer_action(**infer_kwargs)) else: action = torch.cat( @@ -292,101 +243,11 @@ class FastWAMPolicy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor: self.eval() - # TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide - # whether to capture: yes iff this episode's task hasn't been captured yet (so we - # get the first episode of every task). - if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started: - self._debug_episode_started = True - task = self._debug_task_name(batch) - if task not in self._debug_seen_tasks: - self._debug_seen_tasks.add(task) - self._debug_capturing = True - self._debug_episode_task = task - capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing if len(self._action_queue) == 0: actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps] self._action_queue.extend(actions.transpose(0, 1)) - if capturing: - self._debug_step_in_chunk = 0 # a fresh chunk was just predicted - if capturing: - self._debug_capture_pair(batch) - self._debug_step_in_chunk += 1 return self._action_queue.popleft() - # ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ---- - @staticmethod - def _debug_task_name(batch: dict[str, Any]) -> str: - task = batch.get("task") - if isinstance(task, (list, tuple)): - task = task[0] if task else None - return str(task) if task else "no_task" - - def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None: - video = getattr(self, "_debug_last_video", None) - if not video: - return - real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1] - # Map env-step offset within the chunk to a predicted-frame index. The rollout has - # (model_video_frames - 1) transitions over action_horizon actions, so each env step - # advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio, - # e.g. 8/32 = 0.25 — one predicted frame per ~4 actions). - frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon) - idx = min( - int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)), - len(video) - 1, - ) - pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx]) - self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx) - self._debug_pairs.append(pair) - - @staticmethod - def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None: - from PIL import ImageDraw - - draw = ImageDraw.Draw(pair) - draw.text((3, 3), "true", fill=(255, 255, 0)) - draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0)) - - @staticmethod - def _debug_tensor_to_pil(image: Tensor): - from PIL import Image - - # `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE - # encode boundary), so map [0, 1] -> [0, 255] for display. - arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8) - return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy()) - - @staticmethod - def _debug_hstack(left, right): - from PIL import Image - - if right.height != left.height: - right = right.resize((round(right.width * left.height / right.height), left.height)) - canvas = Image.new("RGB", (left.width + right.width, left.height)) - canvas.paste(left, (0, 0)) - canvas.paste(right, (left.width, 0)) - return canvas - - def _save_debug_video(self) -> None: - import re - - import numpy as np - - from lerobot.utils.io_utils import write_video - - pairs = getattr(self, "_debug_pairs", None) - if not pairs: - return - out_dir = Path("outputs/fastwam_debug") - out_dir.mkdir(parents=True, exist_ok=True) - slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task" - path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4" - frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB - write_video(path, frames, fps=30) - logging.info( - "FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path - ) - def _build_core_model(self, config: FastWAMConfig) -> FastWAM: """Build the FastWAM core for training / inference.