mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions
This commit is contained in:
@@ -13,8 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -187,9 +187,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
observation = observation_keys if observation_keys else None
|
||||
|
||||
# Extract padding keys for complementary data
|
||||
# Extract padding and task keys for complementary data
|
||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||
complementary_data = pad_keys if pad_keys else {}
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||
|
||||
return (
|
||||
observation,
|
||||
@@ -225,11 +226,14 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
||||
"info": info,
|
||||
}
|
||||
|
||||
# Add padding data from complementary_data
|
||||
# Add padding and task data from complementary_data
|
||||
if complementary_data:
|
||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||
batch.update(pad_data)
|
||||
|
||||
if "task" in complementary_data:
|
||||
batch["task"] = complementary_data["task"]
|
||||
|
||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||
if isinstance(observation, dict):
|
||||
batch.update(observation)
|
||||
@@ -947,6 +951,40 @@ class InfoProcessor:
|
||||
return transition
|
||||
|
||||
|
||||
class ComplementaryDataProcessor:
|
||||
"""Base class for processors that modify only the complementary data of a transition.
|
||||
|
||||
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||
This class handles the boilerplate of extracting and reinserting the processed complementary data
|
||||
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||
"""
|
||||
|
||||
def complementary_data(self, complementary_data):
|
||||
"""Process the complementary data.
|
||||
|
||||
Args:
|
||||
complementary_data: The complementary data to process
|
||||
|
||||
Returns:
|
||||
The processed complementary data
|
||||
"""
|
||||
return complementary_data
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||
complementary_data = self.complementary_data(complementary_data)
|
||||
transition = (
|
||||
transition[TransitionIndex.OBSERVATION],
|
||||
transition[TransitionIndex.ACTION],
|
||||
transition[TransitionIndex.REWARD],
|
||||
transition[TransitionIndex.DONE],
|
||||
transition[TransitionIndex.TRUNCATED],
|
||||
transition[TransitionIndex.INFO],
|
||||
complementary_data,
|
||||
)
|
||||
return transition
|
||||
|
||||
|
||||
class IdentityProcessor:
|
||||
"""Identity processor that does nothing."""
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
Reference in New Issue
Block a user