mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
Allow Groot fake RTC chunk prefetch
This commit is contained in:
@@ -342,10 +342,14 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
return loss, loss_dict
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **_: object) -> Tensor:
|
||||
"""Predict a chunk of actions for inference by delegating to Isaac-GR00T.
|
||||
|
||||
Returns a tensor of shape (B, n_action_steps, action_dim).
|
||||
|
||||
Groot does not currently implement LeRobot's RTC guidance contract. Accept
|
||||
and ignore action-selection kwargs so the RTC engine can still use Groot as
|
||||
an async chunk producer.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
@@ -419,6 +420,13 @@ def test_groot_n1_7_can_be_selected_from_policy_config_factory_without_external_
|
||||
assert "gr00t" not in sys.modules
|
||||
|
||||
|
||||
def test_groot_predict_action_chunk_accepts_rtc_kwargs():
|
||||
signature = inspect.signature(GrootPolicy.predict_action_chunk)
|
||||
|
||||
assert any(parameter.kind is inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
signature.bind(object(), {}, inference_delay=2, prev_chunk_left_over=None)
|
||||
|
||||
|
||||
def test_groot_from_pretrained_rejects_mismatched_caller_config(tmp_path):
|
||||
model_path = tmp_path / "GR00T-N1.7-local"
|
||||
model_path.mkdir()
|
||||
|
||||
Reference in New Issue
Block a user