diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index abcbb8a8c..381c5fbd6 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING @@ -174,17 +173,14 @@ N_COLOR_CHANNELS = 3 # config -@dataclass class GR00TN15Config(PretrainedConfig): model_type = "gr00t_n1_5" - backbone_cfg: dict = field(init=False, metadata={"help": "Backbone configuration."}) - action_head_cfg: dict = field(init=False, metadata={"help": "Action head configuration."}) - - action_horizon: int = field(init=False, metadata={"help": "Action horizon."}) - - action_dim: int = field(init=False, metadata={"help": "Action dimension."}) - compute_dtype: str = field(default="float32", metadata={"help": "Compute dtype."}) + backbone_cfg: dict + action_head_cfg: dict + action_horizon: int + action_dim: int + compute_dtype: str = "float32" def __init__(self, **kwargs): super().__init__(**kwargs)