change loading subtasks

This commit is contained in:
Pepijn
2025-11-25 22:48:46 +01:00
parent 456d9fe3ff
commit 599c2477c5
+14 -8
View File
@@ -283,18 +283,24 @@ class SARMRewardModel(PreTrainedPolicy):
def _update_num_stages_from_dataset(self, dataset_meta) -> None:
"""Update num_stages in config based on dataset subtask annotations."""
episodes = dataset_meta.episodes
if "annotation.subtask.name" not in episodes:
raise ValueError("No subtask annotations found in dataset annotations")
if episodes is None or len(episodes) == 0:
raise ValueError("No episodes found, using default num_stages")
if 'subtask_names' not in episodes.column_names:
raise ValueError("No subtask annotations found in dataset, using default num_stages")
episodes_df = episodes.to_pandas()
# Collect all unique subtask names
all_subtask_names = set()
for i in range(len(episodes["annotation.subtask.name"])):
subtask_names = episodes["annotation.subtask.name"][i]
if subtask_names:
for name in subtask_names:
all_subtask_names.add(name)
for ep_idx in episodes_df.index:
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
continue
all_subtask_names.update(subtask_names)
if not all_subtask_names:
raise ValueError("No subtask names found in dataset annotations")
raise ValueError("No valid subtask names found, using default num_stages")
# Sort subtask names for consistent ordering
subtask_names = sorted(list(all_subtask_names))