diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 61a856f0b..d6b995ca8 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -40,7 +40,7 @@ T = TypeVar("T", bound="PreTrainedPolicy") class ActionSelectKwargs(TypedDict, total=False): noise: Tensor | None - return_extra: bool + return_intermediate_predictions: bool class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): @@ -196,9 +196,10 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): Child classes using action chunking should use this method within `select_action` to form the action chunk cached for selection. - By default returns just the action `Tensor`. If `return_extra=True`, returns `(action, extra)` - where `extra` is a (possibly empty) `dict[str, Tensor]` of auxiliary outputs a policy may - expose (e.g. world-model predictions). Policies that produce nothing extra may ignore the kwarg. + By default returns just the action `Tensor`. If `return_intermediate_predictions=True`, + returns `(action, predictions)` where `predictions` is a (possibly empty) `dict[str, Tensor]` + of additional model predictions a policy may expose (e.g. world-model predicted frames). + Policies that produce nothing extra may ignore the kwarg. """ raise NotImplementedError @@ -211,9 +212,10 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): When the model uses a history of observations, or outputs a sequence of actions, this method deals with caching. - By default returns just the action `Tensor`. If `return_extra=True`, returns `(action, extra)` - where `extra` is a (possibly empty) `dict[str, Tensor]` of auxiliary outputs a policy may - expose (e.g. world-model predictions). Policies that produce nothing extra may ignore the kwarg. + By default returns just the action `Tensor`. If `return_intermediate_predictions=True`, + returns `(action, predictions)` where `predictions` is a (possibly empty) `dict[str, Tensor]` + of additional model predictions a policy may expose (e.g. world-model predicted frames). + Policies that produce nothing extra may ignore the kwarg. """ raise NotImplementedError