mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
change loadig subtasks
This commit is contained in:
@@ -283,18 +283,24 @@ class SARMRewardModel(PreTrainedPolicy):
|
||||
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
|
||||
"""Update num_stages in config based on dataset subtask annotations."""
|
||||
episodes = dataset_meta.episodes
|
||||
if "annotation.subtask.name" not in episodes:
|
||||
raise ValueError("No subtask annotations found in dataset annotations")
|
||||
if episodes is None or len(episodes) == 0:
|
||||
raise ValueError("No episodes found, using default num_stages")
|
||||
|
||||
if 'subtask_names' not in episodes.column_names:
|
||||
raise ValueError("No subtask annotations found in dataset, using default num_stages")
|
||||
|
||||
episodes_df = episodes.to_pandas()
|
||||
|
||||
# Collect all unique subtask names
|
||||
all_subtask_names = set()
|
||||
for i in range(len(episodes["annotation.subtask.name"])):
|
||||
subtask_names = episodes["annotation.subtask.name"][i]
|
||||
if subtask_names:
|
||||
for name in subtask_names:
|
||||
all_subtask_names.add(name)
|
||||
for ep_idx in episodes_df.index:
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
continue
|
||||
all_subtask_names.update(subtask_names)
|
||||
|
||||
if not all_subtask_names:
|
||||
raise ValueError("No subtask names found in dataset annotations")
|
||||
raise ValueError("No valid subtask names found, using default num_stages")
|
||||
|
||||
# Sort subtask names for consistent ordering
|
||||
subtask_names = sorted(list(all_subtask_names))
|
||||
|
||||
Reference in New Issue
Block a user