fix rewind discrepancies

This commit is contained in:
Pepijn
2025-11-18 16:09:16 +01:00
parent 0d84f4724d
commit 52b080fd8c
4 changed files with 61 additions and 57 deletions
@@ -342,7 +342,7 @@ class VideoAnnotator:
file_path,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
target_fps=2 # 2 FPS is good balance for VLM
target_fps=1
)
is_extracted = extracted_path != file_path
+47 -41
View File
@@ -33,9 +33,7 @@ def sample_video_feature(
video_feature: torch.Tensor,
max_length: int = 32,
random_sample: bool = True,
remaining_length: int = None,
absolute_indices: torch.Tensor = None,
episode_length: int = None
remaining_length: int = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Sample or pad video features to a fixed length with progress targets.
@@ -52,18 +50,13 @@ def sample_video_feature(
This ensures all sequences show increasing progress from near-zero, regardless
of where they're sampled from in the episode.
Note: ReWiND uses consecutive frames loaded via observation_delta_indices.
When video_length > max_length, this function can subsample, but ReWiND
typically loads exactly max_length frames, so no subsampling occurs.
Uses original ReWiND sampling: random start/end points with minimum 3 frames.
Args:
video_feature: Video features tensor (num_frames, feature_dim)
max_length: Target sequence length
random_sample: If True, randomly sample frames. If False, uniformly sample consecutive frames.
ReWiND uses False to preserve temporal order.
remaining_length: Remaining trajectory length from first frame to episode end
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
episode_length: Total length of the episode [for fallback]
Returns:
Tuple of:
@@ -72,21 +65,29 @@ def sample_video_feature(
"""
video_length = len(video_feature)
# Original ReWiND sampling: random start/end with minimum 3 frames
if video_length > 3:
# Sample random start index (ensuring we can get at least 3 frames)
start_idx = random.randint(0, max(0, video_length - 3))
# Sample random end index (at least 3 frames after start, up to video_length)
end_idx = random.randint(min(start_idx + 3, video_length), video_length)
# Extract the sampled segment
video_feature = video_feature[start_idx:end_idx]
# Update video_length for the sampled segment
video_length = len(video_feature)
# Adjust remaining_length to be from start_idx to episode end
if remaining_length is not None:
# The remaining length should be from start_idx to episode end
# If we started at start_idx, we've already consumed start_idx frames
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
if remaining_length is not None:
# CORRECT: Use remaining length from first frame to episode end
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length
elif absolute_indices is not None and episode_length is not None:
# Fallback: Compute remaining length from first frame to episode end
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length_computed = episode_length - first_frame_idx
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length_computed
else:
# Fallback: linear progress (for inference/testing)
progress_targets = torch.linspace(1.0/video_length, 1.0, video_length)
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
progress_targets = progress_indices / remaining_length
if video_length < max_length:
# Pad with last frame
@@ -117,15 +118,13 @@ def sample_reverse_video_feature(
video_feature: torch.Tensor,
max_length: int = 32,
random_sample: bool = True,
remaining_length: int = None,
absolute_indices: torch.Tensor = None,
episode_length: int = None
remaining_length: int = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample video with reverse augmentation (video rewind) - ORIGINAL REWIND LOGIC.
This implements the EXACT video rewind augmentation from the original ReWiND paper:
1. Take forward sequence
1. Take forward sequence (sampled with random start/end, min 3 frames)
2. Append reversed frames from the END backwards
3. Progress increases then decreases (simulating task completion then failure)
@@ -139,8 +138,6 @@ def sample_reverse_video_feature(
max_length: Target sequence length
random_sample: If True, use random sampling for frame selection
remaining_length: Remaining trajectory length from first frame to episode end
absolute_indices: Absolute frame indices in the episode (num_frames,) [for fallback]
episode_length: Total length of the episode [for fallback]
Returns:
Tuple of:
@@ -149,21 +146,30 @@ def sample_reverse_video_feature(
"""
video_length = len(video_feature)
# Original logic: start from first half, end in second half, ensure min 3 frames
if video_length > 3:
# Sample start from first half
start_idx = random.randint(0, video_length // 2)
# Sample end from second half
end_idx = random.randint(video_length // 2, video_length)
# Ensure minimum 3 frames difference (original uses while loop)
while end_idx - start_idx < 3:
start_idx = random.randint(0, video_length // 2)
end_idx = random.randint(video_length // 2, video_length)
# Extract the forward segment
video_feature = video_feature[start_idx:end_idx]
video_length = len(video_feature)
# Adjust remaining_length
if remaining_length is not None:
remaining_length = remaining_length - start_idx if remaining_length > start_idx else video_length
# Generate forward progress targets using ORIGINAL ReWiND formula
# Progress = (position_in_sequence + 1) / remaining_trajectory_length
if remaining_length is not None:
# CORRECT: Use remaining length from first frame to episode end
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length
elif absolute_indices is not None and episode_length is not None:
# Fallback: Compute remaining length from first frame to episode end
first_frame_idx = absolute_indices[0].item() if isinstance(absolute_indices[0], torch.Tensor) else absolute_indices[0]
remaining_length_computed = episode_length - first_frame_idx
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length_computed
else:
# Fallback: linear progress
forward_progress = torch.linspace(1.0/video_length, 1.0, video_length)
progress_indices = torch.arange(1, video_length + 1, dtype=torch.float32)
forward_progress = progress_indices / remaining_length
# ORIGINAL LOGIC: Reverse from END backwards, then append to forward sequence
# Example: video=[A,B,C,D,E] -> reversed=[E,D,C,B,A] -> take some from reversed (skip first)
@@ -39,10 +39,11 @@ class ReWiNDConfig(PreTrainedConfig):
num_layers: int = 4
# Temporal parameters
max_length: int = 32 # Maximum video sequence length
max_length: int = 32 # Maximum video sequence length, ORIGINAL: 16!
subsample_video: bool = True # Whether to pad/subsample videos to max_length
use_temporal_sampler: bool = True # Always enable temporal sequence loading
sequence_stride: int = 1 # Stride between frames when using temporal sampler
rewind_ratio: float = 0.8 # Probability of applying rewind augmentation (original: 0.8)
# Training parameters
batch_size: int = 64
@@ -91,7 +92,7 @@ class ReWiNDConfig(PreTrainedConfig):
def get_optimizer_preset(self) -> AdamWConfig:
"""Get default optimizer configuration for ReWiND training."""
return AdamWConfig(
lr=3e-4,
lr=1e-4,
weight_decay=1e-4,
betas=(0.9, 0.999),
eps=1e-8,
@@ -100,8 +101,8 @@ class ReWiNDConfig(PreTrainedConfig):
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Get default learning rate scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=3e-4,
decay_lr=3e-5,
peak_lr=1e-4,
decay_lr=1e-5,
num_warmup_steps=1000,
num_decay_steps=100000,
)
+8 -11
View File
@@ -569,15 +569,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
else:
current_remaining_length = remaining_lengths.item() if isinstance(remaining_lengths, torch.Tensor) else remaining_lengths
if random.random() < 0.5: # 50% chance of rewind
if random.random() < self.config.rewind_ratio: # Use configurable rewind ratio
# Apply video rewind augmentation (now returns tuple)
rewound_video, progress = sample_reverse_video_feature(
video_features[i],
max_length=max_length,
random_sample=False, # Use consecutive frames, not random sampling
remaining_length=current_remaining_length,
absolute_indices=current_absolute_indices,
episode_length=current_episode_length
random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length
)
processed_videos.append(rewound_video.to(self.device))
progress_targets.append(progress.to(self.device))
@@ -586,10 +584,8 @@ class ReWiNDRewardModel(PreTrainedPolicy):
sampled_video, progress = sample_video_feature(
video_features[i],
max_length=max_length,
random_sample=False, # Use consecutive frames, not random sampling
remaining_length=current_remaining_length,
absolute_indices=current_absolute_indices,
episode_length=current_episode_length
random_sample=True, # Use random sampling (original ReWiND)
remaining_length=current_remaining_length
)
processed_videos.append(sampled_video.to(self.device))
progress_targets.append(progress.to(self.device))
@@ -623,12 +619,13 @@ class ReWiNDRewardModel(PreTrainedPolicy):
# For misaligned pairs, we don't need correct progress targets (will be set to 0)
misaligned_videos_sampled = []
for i in range(batch_size):
# For misaligned videos, use video length as remaining_length
video_len = len(misaligned_videos[i])
sampled, _ = sample_video_feature(
misaligned_videos[i],
max_length=max_length,
random_sample=True,
absolute_indices=None,
episode_length=None
remaining_length=video_len # Use video length for misaligned pairs
)
misaligned_videos_sampled.append(sampled.to(self.device))
misaligned_videos_sampled = torch.stack(misaligned_videos_sampled)