mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
fix normalization in visualization
This commit is contained in:
@@ -43,6 +43,8 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||||
|
from lerobot.policies.sarm.sarm_utils import pad_state_to_max_dim
|
||||||
|
from lerobot.datasets.utils import load_stats
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
@@ -219,6 +221,8 @@ def run_inference(
|
|||||||
frames: np.ndarray,
|
frames: np.ndarray,
|
||||||
states: Optional[np.ndarray],
|
states: Optional[np.ndarray],
|
||||||
task_description: str,
|
task_description: str,
|
||||||
|
dataset_stats: dict | None = None,
|
||||||
|
state_key: str = "observation.state",
|
||||||
batch_size: int = 32
|
batch_size: int = 32
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
@@ -234,6 +238,8 @@ def run_inference(
|
|||||||
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
|
frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
|
||||||
states: Joint states (num_frames, state_dim)
|
states: Joint states (num_frames, state_dim)
|
||||||
task_description: Task description text
|
task_description: Task description text
|
||||||
|
dataset_stats: Dataset statistics for state normalization (same as training)
|
||||||
|
state_key: Key for state in dataset_stats
|
||||||
batch_size: Batch size for processing slices
|
batch_size: Batch size for processing slices
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -258,6 +264,15 @@ def run_inference(
|
|||||||
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
|
text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
|
||||||
if states is not None:
|
if states is not None:
|
||||||
state_embeddings = torch.tensor(states, dtype=torch.float32)
|
state_embeddings = torch.tensor(states, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Normalize states using dataset stats (same as training processor)
|
||||||
|
if dataset_stats is not None and state_key in dataset_stats:
|
||||||
|
mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
|
||||||
|
std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
|
||||||
|
state_embeddings = (state_embeddings - mean) / (std + 1e-8)
|
||||||
|
logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
|
||||||
|
else:
|
||||||
|
logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
|
||||||
else:
|
else:
|
||||||
state_embeddings = None
|
state_embeddings = None
|
||||||
|
|
||||||
@@ -280,6 +295,8 @@ def run_inference(
|
|||||||
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
|
video_slices = torch.stack(video_slices) # (num_frames, num_frames_model, 512)
|
||||||
if state_embeddings is not None:
|
if state_embeddings is not None:
|
||||||
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
|
state_slices = torch.stack(state_slices) # (num_frames, num_frames_model, state_dim)
|
||||||
|
# Pad states to max_state_dim (same as training processor)
|
||||||
|
state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
|
||||||
else:
|
else:
|
||||||
state_slices = None
|
state_slices = None
|
||||||
|
|
||||||
@@ -346,14 +363,13 @@ def visualize_predictions(
|
|||||||
else:
|
else:
|
||||||
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
|
stage_labels = [f'Stage {i+1}' for i in range(num_stages)]
|
||||||
|
|
||||||
if show_frames:
|
# Create figure with progress plot, stage plot, and sample frames
|
||||||
# Create figure with progress plot, stage plot, and sample frames
|
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
||||||
fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
|
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
|
||||||
gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)
|
|
||||||
|
|
||||||
ax_progress = fig.add_subplot(gs[0])
|
ax_progress = fig.add_subplot(gs[0])
|
||||||
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
|
ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
|
||||||
ax_frames = fig.add_subplot(gs[2])
|
ax_frames = fig.add_subplot(gs[2])
|
||||||
|
|
||||||
frame_indices = np.arange(len(progress_predictions))
|
frame_indices = np.arange(len(progress_predictions))
|
||||||
|
|
||||||
@@ -490,11 +506,20 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
image_key = args.image_key if args.image_key is not None else model.config.image_key
|
image_key = args.image_key if args.image_key is not None else model.config.image_key
|
||||||
|
state_key = args.state_key if args.state_key is not None else model.config.state_key
|
||||||
logger.info(f"Using image key: {image_key}")
|
logger.info(f"Using image key: {image_key}")
|
||||||
|
logger.info(f"Using state key: {state_key}")
|
||||||
|
|
||||||
|
# Load dataset stats for state normalization (same as training)
|
||||||
|
dataset_stats = load_stats(dataset.root)
|
||||||
|
if dataset_stats:
|
||||||
|
logger.info(f"✓ Loaded dataset stats from {dataset.root}")
|
||||||
|
else:
|
||||||
|
logger.warning("⚠ Could not load dataset stats - states will not be normalized")
|
||||||
|
|
||||||
# Load episode data
|
# Load episode data
|
||||||
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
|
frames, states, start_idx, end_idx, dataset_task = load_episode_data(
|
||||||
dataset, args.episode_index, image_key, args.state_key
|
dataset, args.episode_index, image_key, state_key
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use task description from dataset if available, otherwise use command-line argument
|
# Use task description from dataset if available, otherwise use command-line argument
|
||||||
@@ -502,7 +527,10 @@ def main():
|
|||||||
logger.info(f"Using task description: '{task_description}'")
|
logger.info(f"Using task description: '{task_description}'")
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
progress_predictions, stage_predictions = run_inference(model, frames, states, task_description)
|
progress_predictions, stage_predictions = run_inference(
|
||||||
|
model, frames, states, task_description,
|
||||||
|
dataset_stats=dataset_stats, state_key=state_key
|
||||||
|
)
|
||||||
|
|
||||||
# Extract subtask names and temporal proportions from model config if available
|
# Extract subtask names and temporal proportions from model config if available
|
||||||
subtask_names = None
|
subtask_names = None
|
||||||
|
|||||||
Reference in New Issue
Block a user