diff --git a/oxe_utils/transforms.py b/oxe_utils/transforms.py index a66d1ce..c845fc6 100644 --- a/oxe_utils/transforms.py +++ b/oxe_utils/transforms.py @@ -617,6 +617,7 @@ def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]: + trajectory["observation"]["gripper"] = trajectory["observation"]["gripper"][:, None] return trajectory