diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index 7cb4b879f..d0a49fdd5 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -283,18 +283,24 @@ class SARMRewardModel(PreTrainedPolicy): def _update_num_stages_from_dataset(self, dataset_meta) -> None: """Update num_stages in config based on dataset subtask annotations.""" episodes = dataset_meta.episodes - if "annotation.subtask.name" not in episodes: - raise ValueError("No subtask annotations found in dataset annotations") + if episodes is None or len(episodes) == 0: + raise ValueError("No episodes found, using default num_stages") + if 'subtask_names' not in episodes.column_names: + raise ValueError("No subtask annotations found in dataset, using default num_stages") + + episodes_df = episodes.to_pandas() + + # Collect all unique subtask names all_subtask_names = set() - for i in range(len(episodes["annotation.subtask.name"])): - subtask_names = episodes["annotation.subtask.name"][i] - if subtask_names: - for name in subtask_names: - all_subtask_names.add(name) + for ep_idx in episodes_df.index: + subtask_names = episodes_df.loc[ep_idx, 'subtask_names'] + if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)): + continue + all_subtask_names.update(subtask_names) if not all_subtask_names: - raise ValueError("No subtask names found in dataset annotations") + raise ValueError("No valid subtask names found, using default num_stages") # Sort subtask names for consistent ordering subtask_names = sorted(list(all_subtask_names))