mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix(rewards/topreward): fix pyproject extra typo and simplify processor (#3653)
Add lerobot[topreward] extra to all in pyproject.toml, drop the redundant labels arg in scoring, and collapse the dead-branch shape check in the encoder processor.
This commit is contained in:
@@ -287,6 +287,7 @@ all = [
|
||||
"lerobot[libero]; sys_platform == 'linux'",
|
||||
"lerobot[metaworld]",
|
||||
"lerobot[sarm]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
@@ -140,7 +140,7 @@ class TOPRewardModel(PreTrainedRewardModel):
|
||||
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
outputs = self.model(**inputs, labels=labels)
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
logits = outputs.logits[:, :-1, :]
|
||||
target_labels = labels[:, 1:]
|
||||
|
||||
@@ -286,10 +286,9 @@ class TOPRewardEncoderProcessorStep(ProcessorStep):
|
||||
padded.append(t)
|
||||
result[key] = torch.cat(padded, dim=0)
|
||||
else:
|
||||
if all(t.shape == tensors[0].shape for t in tensors):
|
||||
result[key] = torch.cat(tensors, dim=0)
|
||||
else:
|
||||
result[key] = torch.cat(tensors, dim=0)
|
||||
# Vision tensors (pixel_values_videos, image_grid_thw, etc.) are expected
|
||||
# to have matching shapes since max_frames is applied uniformly per batch
|
||||
result[key] = torch.cat(tensors, dim=0)
|
||||
|
||||
for key in encoded_list[0]:
|
||||
if key not in result:
|
||||
|
||||
@@ -3172,6 +3172,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["scipy-dep"], marker = "extra == 'wallx'" },
|
||||
{ name = "lerobot", extras = ["smolvla"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["test"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["topreward"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["training"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'eo1'" },
|
||||
{ name = "lerobot", extras = ["transformers-dep"], marker = "extra == 'groot'" },
|
||||
|
||||
Reference in New Issue
Block a user