This commit is contained in:
Pepijn
2025-08-30 12:28:18 +02:00
parent a4fc02a636
commit 7440d772ff
+42 -5
View File
@@ -421,10 +421,8 @@ class RLearNPolicy(PreTrainedPolicy):
target = augmented_target[:, idx]
else:
# Calculate true episode progress using episode_index and frame_index from batch
if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"):
# Get episode indices and frame indices from batch
episode_indices = batch["episode_index"] # Shape: (B,)
frame_indices = batch["frame_index"] # Shape: (B,)
episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch)
if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None:
# Calculate progress for the current frame in each sample
progress_values = []
@@ -485,7 +483,10 @@ class RLearNPolicy(PreTrainedPolicy):
target = target[:, idx]
else:
raise ValueError(
"No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present."
"No episode information found to build full-episode progress. "
"Expected 'episode_index' and 'frame_index' in batch and a valid 'episode_data_index' on the policy. "
"Please pass RLearNPolicy(episode_data_index=...) built from episodes.jsonl (per-episode lengths), "
"and ensure the dataset exposes 'episode_index' and 'frame_index' (shape (B,) or (B,1))."
)
# During inference, we might not want to compute loss
@@ -564,6 +565,42 @@ class RLearNPolicy(PreTrainedPolicy):
return total_loss, loss_dict
def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]:
"""Try to extract (episode_index, frame_index) tensors from batch or complementary data.
Accepts shapes (B,) or (B,1) and returns 1D long tensors on the model device.
"""
device = next(self.parameters()).device
ep = batch.get("episode_index")
fr = batch.get("frame_index")
# Try complementary_data
if (ep is None or fr is None) and isinstance(batch.get("complementary_data"), dict):
comp = batch["complementary_data"]
ep = comp.get("episode_index", ep)
fr = comp.get("frame_index", fr)
if ep is None or fr is None:
return None, None
# Convert to 1D long tensors on device
if torch.is_tensor(ep):
if ep.dim() == 2 and ep.shape[1] == 1:
ep = ep.squeeze(1)
ep = ep.to(device=device, dtype=torch.long)
else:
ep = torch.as_tensor(ep, device=device, dtype=torch.long)
if torch.is_tensor(fr):
if fr.dim() == 2 and fr.shape[1] == 1:
fr = fr.squeeze(1)
fr = fr.to(device=device, dtype=torch.long)
else:
fr = torch.as_tensor(fr, device=device, dtype=torch.long)
return ep, fr
# Helper functions for ReWiND architecture