fix(rewards/topreward): fix pyproject extra typo and simplify processor (#3653)

Add lerobot[topreward] extra to all in
pyproject.toml, drop the redundant labels arg in scoring, and
collapse the dead-branch shape check in the encoder processor.
This commit is contained in:
Cole
2026-05-22 17:27:09 -05:00
committed by GitHub
parent 5cfca59ec7
commit 616663cd9f
4 changed files with 6 additions and 5 deletions
+1
View File
@@ -287,6 +287,7 @@ all = [
"lerobot[libero]; sys_platform == 'linux'", "lerobot[libero]; sys_platform == 'linux'",
"lerobot[metaworld]", "lerobot[metaworld]",
"lerobot[sarm]", "lerobot[sarm]",
"lerobot[topreward]",
"lerobot[peft]", "lerobot[peft]",
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2 # "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
] ]
@@ -140,7 +140,7 @@ class TOPRewardModel(PreTrainedRewardModel):
self.eval() self.eval()
with torch.no_grad(): with torch.no_grad():
outputs = self.model(**inputs, labels=labels) outputs = self.model(**inputs)
logits = outputs.logits[:, :-1, :] logits = outputs.logits[:, :-1, :]
target_labels = labels[:, 1:] target_labels = labels[:, 1:]
@@ -286,10 +286,9 @@ class TOPRewardEncoderProcessorStep(ProcessorStep):
padded.append(t) padded.append(t)
result[key] = torch.cat(padded, dim=0) result[key] = torch.cat(padded, dim=0)
else: else:
if all(t.shape == tensors[0].shape for t in tensors): # Vision tensors (pixel_values_videos, image_grid_thw, etc.) are expected
result[key] = torch.cat(tensors, dim=0) # to have matching shapes since max_frames is applied uniformly per batch
else: result[key] = torch.cat(tensors, dim=0)
result[key] = torch.cat(tensors, dim=0)
for key in encoded_list[0]: for key in encoded_list[0]:
if key not in result: if key not in result:
Generated
+1
View File
@@ -3172,6 +3172,7 @@ requires-dist = [
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" }, { name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" }, { name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" }, { name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },