diff --git a/pyproject.toml b/pyproject.toml index 4c62d965c..42b01c616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -287,6 +287,7 @@ all = [ "lerobot[libero]; sys_platform == 'linux'", "lerobot[metaworld]", "lerobot[sarm]", + "lerobot[topreward]", "lerobot[peft]", # "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2 ] diff --git a/src/lerobot/rewards/topreward/modeling_topreward.py b/src/lerobot/rewards/topreward/modeling_topreward.py index 3c2c0efa8..4f057e737 100644 --- a/src/lerobot/rewards/topreward/modeling_topreward.py +++ b/src/lerobot/rewards/topreward/modeling_topreward.py @@ -140,7 +140,7 @@ class TOPRewardModel(PreTrainedRewardModel): self.eval() with torch.no_grad(): - outputs = self.model(**inputs, labels=labels) + outputs = self.model(**inputs) logits = outputs.logits[:, :-1, :] target_labels = labels[:, 1:] diff --git a/src/lerobot/rewards/topreward/processor_topreward.py b/src/lerobot/rewards/topreward/processor_topreward.py index 160f87fc5..6bf73fcb4 100644 --- a/src/lerobot/rewards/topreward/processor_topreward.py +++ b/src/lerobot/rewards/topreward/processor_topreward.py @@ -286,10 +286,9 @@ class TOPRewardEncoderProcessorStep(ProcessorStep): padded.append(t) result[key] = torch.cat(padded, dim=0) else: - if all(t.shape == tensors[0].shape for t in tensors): - result[key] = torch.cat(tensors, dim=0) - else: - result[key] = torch.cat(tensors, dim=0) + # Vision tensors (pixel_values_videos, image_grid_thw, etc.) are expected + # to have matching shapes since max_frames is applied uniformly per batch + result[key] = torch.cat(tensors, dim=0) for key in encoded_list[0]: if key not in result: diff --git a/uv.lock b/uv.lock index 3da9e9e43..2c793451a 100644 --- a/uv.lock +++ b/uv.lock @@ -3172,6 +3172,7 @@ requires-dist = [ { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["test"], marker = "extra == 'all'" }, + { name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["training"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },