mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
removing temporary debug code
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -35,16 +34,6 @@ from .wan_components import (
|
||||
)
|
||||
from .wan_video_dit import WanVideoDiT
|
||||
|
||||
# TEMPORARY DEBUG — revert before merge. When FASTWAM_DECODE_DEBUG=1, route the first
|
||||
# eval episode's action chunks through `infer_joint` so the predicted video latents are
|
||||
# decoded by the VAE and dumped as PNG frames (sanity-checks the diffusers decode path).
|
||||
_FASTWAM_DECODE_DEBUG = os.environ.get("FASTWAM_DECODE_DEBUG") == "1"
|
||||
# Debug viz knob: extra divisor on the predicted-frame advance per env step. Should be 1
|
||||
# now that the model emits model_video_frames (so frames_per_step = (model_video_frames-1)/
|
||||
# action_horizon already encodes the action_video_freq_ratio). Was 4 to compensate for the
|
||||
# (now-fixed) bug where the model ran on the un-subsampled num_video_frames.
|
||||
_DEBUG_PRED_RATE_DIV = 1
|
||||
|
||||
|
||||
class FastWAMPolicy(PreTrainedPolicy):
|
||||
"""LeRobot policy wrapper for FastWAM.
|
||||
@@ -84,17 +73,6 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
if "video" in layer.blocks:
|
||||
layer.blocks["video"].requires_grad_(False)
|
||||
self.reset()
|
||||
# TEMPORARY DEBUG — revert before merge. Mark construction done so `reset()`
|
||||
# counts only eval-rollout resets (one per episode), not this __init__ one.
|
||||
self._debug_constructed = True
|
||||
self._debug_episode_index = -1
|
||||
self._debug_seen_tasks: set[str] = set()
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video: list | None = None
|
||||
self._debug_pairs: list = []
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model, model_file: str, map_location: str, strict: bool):
|
||||
@@ -179,19 +157,6 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
|
||||
def reset(self) -> None:
|
||||
self._action_queue: deque[Tensor] = deque([], maxlen=self.config.n_action_steps)
|
||||
# TEMPORARY DEBUG — revert before merge. Flush the just-finished episode's
|
||||
# true-vs-pred video if it was a captured one (pairs accumulate only while
|
||||
# capturing), then reset per-episode capture state.
|
||||
if getattr(self, "_debug_constructed", False):
|
||||
if _FASTWAM_DECODE_DEBUG and self._debug_pairs:
|
||||
self._save_debug_video()
|
||||
self._debug_episode_index += 1
|
||||
self._debug_capturing = False
|
||||
self._debug_episode_started = False
|
||||
self._debug_episode_task = ""
|
||||
self._debug_step_in_chunk = 0
|
||||
self._debug_last_video = None
|
||||
self._debug_pairs = []
|
||||
|
||||
def _batch_to_training_sample(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Adapt a standard LeRobot batch to the FastWAM-native sample that
|
||||
@@ -259,21 +224,7 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
infer_kwargs = _batch_to_infer_kwargs(batch=batch, config=self.config)
|
||||
batch_size = _infer_kwargs_batch_size(infer_kwargs)
|
||||
# TEMPORARY DEBUG — revert before merge. On captured episodes (first of each task),
|
||||
# run the joint video+action path so the predicted video is VAE-decoded; stash it
|
||||
# so select_action can pair each predicted frame with the real obs that follows.
|
||||
if _FASTWAM_DECODE_DEBUG and getattr(self, "_debug_capturing", False) and batch_size == 1:
|
||||
out = self.model.infer_joint(
|
||||
**infer_kwargs,
|
||||
num_video_frames=self.config.model_video_frames,
|
||||
test_action_with_infer_action=False,
|
||||
)
|
||||
# The decoded rollout has model_video_frames frames spanning the full
|
||||
# action_horizon (action_video_freq_ratio actions per frame); the per-step
|
||||
# pairing indexes into it, so keep all frames.
|
||||
self._debug_last_video = out["video"]
|
||||
action = _action_from_model_output(out)
|
||||
elif batch_size == 1:
|
||||
if batch_size == 1:
|
||||
action = _action_from_model_output(self.model.infer_action(**infer_kwargs))
|
||||
else:
|
||||
action = torch.cat(
|
||||
@@ -292,101 +243,11 @@ class FastWAMPolicy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs: Any) -> Tensor:
|
||||
self.eval()
|
||||
# TEMPORARY DEBUG — revert before merge. On the first step of each episode, decide
|
||||
# whether to capture: yes iff this episode's task hasn't been captured yet (so we
|
||||
# get the first episode of every task).
|
||||
if _FASTWAM_DECODE_DEBUG and not self._debug_episode_started:
|
||||
self._debug_episode_started = True
|
||||
task = self._debug_task_name(batch)
|
||||
if task not in self._debug_seen_tasks:
|
||||
self._debug_seen_tasks.add(task)
|
||||
self._debug_capturing = True
|
||||
self._debug_episode_task = task
|
||||
capturing = _FASTWAM_DECODE_DEBUG and self._debug_capturing
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
if capturing:
|
||||
self._debug_step_in_chunk = 0 # a fresh chunk was just predicted
|
||||
if capturing:
|
||||
self._debug_capture_pair(batch)
|
||||
self._debug_step_in_chunk += 1
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# ---- TEMPORARY DEBUG (revert before merge): true-vs-predicted video capture ----
|
||||
@staticmethod
|
||||
def _debug_task_name(batch: dict[str, Any]) -> str:
|
||||
task = batch.get("task")
|
||||
if isinstance(task, (list, tuple)):
|
||||
task = task[0] if task else None
|
||||
return str(task) if task else "no_task"
|
||||
|
||||
def _debug_capture_pair(self, batch: dict[str, Tensor]) -> None:
|
||||
video = getattr(self, "_debug_last_video", None)
|
||||
if not video:
|
||||
return
|
||||
real = _input_image_from_batch(batch, self.config)[0] # [C,H,W] in [-1,1]
|
||||
# Map env-step offset within the chunk to a predicted-frame index. The rollout has
|
||||
# (model_video_frames - 1) transitions over action_horizon actions, so each env step
|
||||
# advances frames_per_step = (model_video_frames-1)/action_horizon frames (= 1/ratio,
|
||||
# e.g. 8/32 = 0.25 — one predicted frame per ~4 actions).
|
||||
frames_per_step = (self.config.model_video_frames - 1) / max(1, self.config.action_horizon)
|
||||
idx = min(
|
||||
int(round(self._debug_step_in_chunk * frames_per_step / _DEBUG_PRED_RATE_DIV)),
|
||||
len(video) - 1,
|
||||
)
|
||||
pair = self._debug_hstack(self._debug_tensor_to_pil(real), video[idx])
|
||||
self._debug_label_pair(pair, left_w=real.shape[-1], pred_idx=idx)
|
||||
self._debug_pairs.append(pair)
|
||||
|
||||
@staticmethod
|
||||
def _debug_label_pair(pair, left_w: int, pred_idx: int) -> None:
|
||||
from PIL import ImageDraw
|
||||
|
||||
draw = ImageDraw.Draw(pair)
|
||||
draw.text((3, 3), "true", fill=(255, 255, 0))
|
||||
draw.text((left_w + 3, 3), f"pred[t+{pred_idx}]", fill=(0, 255, 0))
|
||||
|
||||
@staticmethod
|
||||
def _debug_tensor_to_pil(image: Tensor):
|
||||
from PIL import Image
|
||||
|
||||
# `real` is the model input in [0, 1] (VISUAL is IDENTITY; the [-1,1] map lives at the VAE
|
||||
# encode boundary), so map [0, 1] -> [0, 255] for display.
|
||||
arr = (image.detach().float().clamp(0.0, 1.0) * 255.0).to(torch.uint8)
|
||||
return Image.fromarray(arr.cpu().permute(1, 2, 0).numpy())
|
||||
|
||||
@staticmethod
|
||||
def _debug_hstack(left, right):
|
||||
from PIL import Image
|
||||
|
||||
if right.height != left.height:
|
||||
right = right.resize((round(right.width * left.height / right.height), left.height))
|
||||
canvas = Image.new("RGB", (left.width + right.width, left.height))
|
||||
canvas.paste(left, (0, 0))
|
||||
canvas.paste(right, (left.width, 0))
|
||||
return canvas
|
||||
|
||||
def _save_debug_video(self) -> None:
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.io_utils import write_video
|
||||
|
||||
pairs = getattr(self, "_debug_pairs", None)
|
||||
if not pairs:
|
||||
return
|
||||
out_dir = Path("outputs/fastwam_debug")
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
slug = re.sub(r"[^a-zA-Z0-9]+", "_", self._debug_episode_task).strip("_")[:40] or "task"
|
||||
path = out_dir / f"ep{self._debug_episode_index:03d}_{slug}_true_vs_pred.mp4"
|
||||
frames = [np.asarray(pair) for pair in pairs] # HWC uint8 RGB
|
||||
write_video(path, frames, fps=30)
|
||||
logging.info(
|
||||
"FASTWAM_DECODE_DEBUG: wrote %d-frame mp4 (left=true, right=pred) to %s", len(frames), path
|
||||
)
|
||||
|
||||
def _build_core_model(self, config: FastWAMConfig) -> FastWAM:
|
||||
"""Build the FastWAM core for training / inference.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user