From 7440d772ff8350656e33477e96dae60638c25924 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 30 Aug 2025 12:28:18 +0200 Subject: [PATCH] fix --- .../policies/rlearn/modeling_rlearn.py | 47 +++++++++++++++++-- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/src/lerobot/policies/rlearn/modeling_rlearn.py b/src/lerobot/policies/rlearn/modeling_rlearn.py index 5ab83f4f3..c533f7b64 100644 --- a/src/lerobot/policies/rlearn/modeling_rlearn.py +++ b/src/lerobot/policies/rlearn/modeling_rlearn.py @@ -421,10 +421,8 @@ class RLearNPolicy(PreTrainedPolicy): target = augmented_target[:, idx] else: # Calculate true episode progress using episode_index and frame_index from batch - if "episode_index" in batch and "frame_index" in batch and hasattr(self, "episode_data_index"): - # Get episode indices and frame indices from batch - episode_indices = batch["episode_index"] # Shape: (B,) - frame_indices = batch["frame_index"] # Shape: (B,) + episode_indices, frame_indices = self._extract_episode_and_frame_indices(batch) + if episode_indices is not None and frame_indices is not None and self.episode_data_index is not None: # Calculate progress for the current frame in each sample progress_values = [] @@ -485,7 +483,10 @@ class RLearNPolicy(PreTrainedPolicy): target = target[:, idx] else: raise ValueError( - "No episode information found in batch. Please ensure 'episode_index' and 'frame_index' keys are present." + "No episode information found to build full-episode progress. " + "Expected 'episode_index' and 'frame_index' in batch and a valid 'episode_data_index' on the policy. " + "Please pass RLearNPolicy(episode_data_index=...) built from episodes.jsonl (per-episode lengths), " + "and ensure the dataset exposes 'episode_index' and 'frame_index' (shape (B,) or (B,1))." ) # During inference, we might not want to compute loss @@ -564,6 +565,42 @@ class RLearNPolicy(PreTrainedPolicy): return total_loss, loss_dict + def _extract_episode_and_frame_indices(self, batch: dict[str, Tensor]) -> tuple[Tensor | None, Tensor | None]: + """Try to extract (episode_index, frame_index) tensors from batch or complementary data. + + Accepts shapes (B,) or (B,1) and returns 1D long tensors on the model device. + """ + device = next(self.parameters()).device + + ep = batch.get("episode_index") + fr = batch.get("frame_index") + + # Try complementary_data + if (ep is None or fr is None) and isinstance(batch.get("complementary_data"), dict): + comp = batch["complementary_data"] + ep = comp.get("episode_index", ep) + fr = comp.get("frame_index", fr) + + if ep is None or fr is None: + return None, None + + # Convert to 1D long tensors on device + if torch.is_tensor(ep): + if ep.dim() == 2 and ep.shape[1] == 1: + ep = ep.squeeze(1) + ep = ep.to(device=device, dtype=torch.long) + else: + ep = torch.as_tensor(ep, device=device, dtype=torch.long) + + if torch.is_tensor(fr): + if fr.dim() == 2 and fr.shape[1] == 1: + fr = fr.squeeze(1) + fr = fr.to(device=device, dtype=torch.long) + else: + fr = torch.as_tensor(fr, device=device, dtype=torch.long) + + return ep, fr + # Helper functions for ReWiND architecture