mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions
This commit is contained in:
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|||||||
@@ -187,9 +187,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
|
|||||||
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||||
observation = observation_keys if observation_keys else None
|
observation = observation_keys if observation_keys else None
|
||||||
|
|
||||||
# Extract padding keys for complementary data
|
# Extract padding and task keys for complementary data
|
||||||
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
|
||||||
complementary_data = pad_keys if pad_keys else {}
|
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||||
|
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
observation,
|
observation,
|
||||||
@@ -225,11 +226,14 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
|
|||||||
"info": info,
|
"info": info,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add padding data from complementary_data
|
# Add padding and task data from complementary_data
|
||||||
if complementary_data:
|
if complementary_data:
|
||||||
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
|
||||||
batch.update(pad_data)
|
batch.update(pad_data)
|
||||||
|
|
||||||
|
if "task" in complementary_data:
|
||||||
|
batch["task"] = complementary_data["task"]
|
||||||
|
|
||||||
# Handle observation - flatten dict to observation.* keys if it's a dict
|
# Handle observation - flatten dict to observation.* keys if it's a dict
|
||||||
if isinstance(observation, dict):
|
if isinstance(observation, dict):
|
||||||
batch.update(observation)
|
batch.update(observation)
|
||||||
@@ -947,6 +951,40 @@ class InfoProcessor:
|
|||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
|
||||||
|
class ComplementaryDataProcessor:
|
||||||
|
"""Base class for processors that modify only the complementary data of a transition.
|
||||||
|
|
||||||
|
Subclasses should override the `complementary_data` method to implement custom complementary data processing.
|
||||||
|
This class handles the boilerplate of extracting and reinserting the processed complementary data
|
||||||
|
into the transition tuple, eliminating the need to implement the `__call__` method in subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def complementary_data(self, complementary_data):
|
||||||
|
"""Process the complementary data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
complementary_data: The complementary data to process
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed complementary data
|
||||||
|
"""
|
||||||
|
return complementary_data
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA]
|
||||||
|
complementary_data = self.complementary_data(complementary_data)
|
||||||
|
transition = (
|
||||||
|
transition[TransitionIndex.OBSERVATION],
|
||||||
|
transition[TransitionIndex.ACTION],
|
||||||
|
transition[TransitionIndex.REWARD],
|
||||||
|
transition[TransitionIndex.DONE],
|
||||||
|
transition[TransitionIndex.TRUNCATED],
|
||||||
|
transition[TransitionIndex.INFO],
|
||||||
|
complementary_data,
|
||||||
|
)
|
||||||
|
return transition
|
||||||
|
|
||||||
|
|
||||||
class IdentityProcessor:
|
class IdentityProcessor:
|
||||||
"""Identity processor that does nothing."""
|
"""Identity processor that does nothing."""
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user