Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-14 16:19:45 +00:00)
refactor(processor): improve processor pipeline typing with generic type (#1810)
* refactor(processor): introduce generic type for to_output
  - Always return `TOutput`
  - Remove `_prepare_transition`, so `__call__` now always returns `TOutput`
  - Update tests accordingly
  - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor

* refactor(processor): consolidate ProcessorKwargs usage across policies
  - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline.
  - Updated multiple policy files to use the new ProcessorKwargs structure for preprocessor and postprocessor arguments.
  - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided.
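For orientation, a minimal sketch (not part of this commit) of how the new keyword plumbing is used. It mirrors the updated tests in the diff below; the only assumptions are that `ACTConfig` is constructible with defaults and that stats may be omitted:

```python
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.processor_act import make_act_pre_post_processors
from lerobot.processor import ProcessorKwargs

# Identity conversions keep the raw EnvTransition dict on both ends of the
# pipeline instead of converting to/from a batch dict, as the updated tests do.
identity: ProcessorKwargs = {"to_transition": lambda x: x, "to_output": lambda x: x}

preprocessor, postprocessor = make_act_pre_post_processors(
    config=ACTConfig(),  # assumption: ACTConfig has usable defaults
    dataset_stats=None,  # stats are optional per the new signature
    preprocessor_kwargs=identity,
    postprocessor_kwargs=identity,
)
```

Because both factory parameters default to `None` and are replaced by empty dicts, existing callers that pass only `config` and `dataset_stats` keep working unchanged.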
@@ -20,6 +20,7 @@ from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -28,8 +29,16 @@ from lerobot.processor import (


 def make_act_pre_post_processors(
-    config: ACTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: ACTConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),
         NormalizerProcessor(
@@ -46,6 +55,16 @@ def make_act_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -21,6 +21,7 @@ from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (


 def make_diffusion_pre_post_processors(
-    config: DiffusionConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: DiffusionConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),
         NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_diffusion_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -38,7 +38,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
 from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
-from lerobot.processor.pipeline import RobotProcessor
+from lerobot.processor.pipeline import ProcessorKwargs, RobotProcessor


 def get_policy_class(name: str) -> type[PreTrainedPolicy]:
@@ -114,6 +114,8 @@ class ProcessorConfigKwargs(TypedDict, total=False):
     preprocessor_overrides: dict[str, Any] | None
     postprocessor_overrides: dict[str, Any] | None
     dataset_stats: dict[str, dict[str, torch.Tensor]] | None
+    preprocessor_kwargs: ProcessorKwargs | None
+    postprocessor_kwargs: ProcessorKwargs | None


 def make_pre_post_processors(
@@ -139,16 +141,24 @@ def make_pre_post_processors(
         NotImplementedError: If the policy type doesn't have a processor implemented.
     """
     if pretrained_path:
+        # Extract preprocessor and postprocessor kwargs
+        preprocessor_kwargs = kwargs.get("preprocessor_kwargs", {})
+        postprocessor_kwargs = kwargs.get("postprocessor_kwargs", {})
+
         return (
             RobotProcessor.from_pretrained(
                 pretrained_model_name_or_path=pretrained_path,
                 config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
                 overrides=kwargs.get("preprocessor_overrides", {}),
+                to_transition=preprocessor_kwargs.get("to_transition"),
+                to_output=preprocessor_kwargs.get("to_output"),
             ),
             RobotProcessor.from_pretrained(
                 pretrained_model_name_or_path=pretrained_path,
                 config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
                 overrides=kwargs.get("postprocessor_overrides", {}),
+                to_transition=postprocessor_kwargs.get("to_transition"),
+                to_output=postprocessor_kwargs.get("to_output"),
             ),
         )

@@ -157,61 +167,90 @@ def make_pre_post_processors(
         from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors

         processors = make_tdmpc_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, DiffusionConfig):
         from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors

         processors = make_diffusion_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, ACTConfig):
         from lerobot.policies.act.processor_act import make_act_pre_post_processors

         processors = make_act_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, VQBeTConfig):
         from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors

         processors = make_vqbet_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, PI0Config):
         from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors

         processors = make_pi0_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

-    elif isinstance(policy_cfg, PI0Config):
+    elif isinstance(policy_cfg, PI0FASTConfig):
         from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors

         processors = make_pi0fast_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, SACConfig):
         from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors

         processors = make_sac_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     elif isinstance(policy_cfg, RewardClassifierConfig):
         from lerobot.policies.sac.reward_model.processor_classifier import make_classifier_processor

-        processors = make_classifier_processor(config=policy_cfg, dataset_stats=kwargs.get("dataset_stats"))
+        processors = make_classifier_processor(
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
+        )

     elif isinstance(policy_cfg, SmolVLAConfig):
         from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors

         processors = make_smolvla_pre_post_processors(
-            config=policy_cfg, dataset_stats=kwargs.get("dataset_stats")
+            config=policy_cfg,
+            dataset_stats=kwargs.get("dataset_stats"),
+            preprocessor_kwargs=kwargs.get("preprocessor_kwargs"),
+            postprocessor_kwargs=kwargs.get("postprocessor_kwargs"),
         )

     else:
@@ -22,6 +22,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RobotProcessor,
     ToBatchProcessor,
     TokenizerProcessor,
@@ -65,8 +66,16 @@ class Pi0NewLineProcessor(ComplementaryDataProcessor):


 def make_pi0_pre_post_processors(
-    config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: PI0Config,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     # Add remaining processors
     input_steps: list[ProcessorStep] = [
         RenameProcessor(rename_map={}),  # To mimic the same processor as pretrained one
@@ -93,6 +102,15 @@ def make_pi0_pre_post_processors(
         ),
     ]

-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -21,6 +21,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (


 def make_pi0fast_pre_post_processors(
-    config: PI0Config, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: PI0Config,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),  # To mimic the same processor as pretrained one
         NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_pi0fast_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -22,6 +22,7 @@ from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -30,8 +31,16 @@ from lerobot.processor import (


 def make_sac_pre_post_processors(
-    config: SACConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: SACConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),
         NormalizerProcessor(
@@ -48,6 +57,15 @@ def make_sac_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -20,13 +20,22 @@ from lerobot.processor import (
     DeviceProcessor,
     IdentityProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RobotProcessor,
 )


 def make_classifier_processor(
-    config: RewardClassifierConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: RewardClassifierConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         NormalizerProcessor(
             features=config.input_features, norm_map=config.normalization_mapping, stats=dataset_stats
@@ -37,6 +46,16 @@ def make_classifier_processor(
         DeviceProcessor(device=config.device),
     ]
     output_steps = [DeviceProcessor(device="cpu"), IdentityProcessor()]
-    return RobotProcessor(steps=input_steps, name="classifier_preprocessor"), RobotProcessor(
-        steps=output_steps, name="classifier_postprocessor"
+
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name="classifier_preprocessor",
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name="classifier_postprocessor",
+            **postprocessor_kwargs,
+        ),
     )
@@ -21,6 +21,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -34,8 +35,16 @@ from lerobot.processor.pipeline import (


 def make_smolvla_pre_post_processors(
-    config: SmolVLAConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: SmolVLAConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),  # To mimic the same processor as pretrained one
         NormalizerProcessor(
@@ -59,8 +68,17 @@ def make_smolvla_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -21,6 +21,7 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -29,8 +30,16 @@ from lerobot.processor import (


 def make_tdmpc_pre_post_processors(
-    config: TDMPCConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: TDMPCConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),
         NormalizerProcessor(
@@ -47,6 +56,15 @@ def make_tdmpc_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -22,6 +22,7 @@ from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.processor import (
     DeviceProcessor,
     NormalizerProcessor,
+    ProcessorKwargs,
     RenameProcessor,
     RobotProcessor,
     ToBatchProcessor,
@@ -30,8 +31,16 @@ from lerobot.processor import (


 def make_vqbet_pre_post_processors(
-    config: VQBeTConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
+    config: VQBeTConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+    preprocessor_kwargs: ProcessorKwargs | None = None,
+    postprocessor_kwargs: ProcessorKwargs | None = None,
 ) -> tuple[RobotProcessor, RobotProcessor]:
+    if preprocessor_kwargs is None:
+        preprocessor_kwargs = {}
+    if postprocessor_kwargs is None:
+        postprocessor_kwargs = {}
+
     input_steps = [
         RenameProcessor(rename_map={}),  # Let the possibility to the user to rename the keys
         NormalizerProcessor(
@@ -48,6 +57,15 @@ def make_vqbet_pre_post_processors(
             features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
         ),
     ]
-    return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor(
-        steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME
+    return (
+        RobotProcessor(
+            steps=input_steps,
+            name=PREPROCESSOR_DEFAULT_NAME,
+            **preprocessor_kwargs,
+        ),
+        RobotProcessor(
+            steps=output_steps,
+            name=POSTPROCESSOR_DEFAULT_NAME,
+            **postprocessor_kwargs,
+        ),
     )
@@ -37,6 +37,7 @@ from .pipeline import (
     IdentityProcessor,
     InfoProcessor,
     ObservationProcessor,
+    ProcessorKwargs,
     ProcessorStep,
     ProcessorStepRegistry,
     RewardProcessor,
@@ -68,6 +69,7 @@ __all__ = [
     "UnnormalizerProcessor",
     "hotswap_stats",
     "ObservationProcessor",
+    "ProcessorKwargs",
     "ProcessorStep",
     "ProcessorStepRegistry",
     "RenameProcessor",
@@ -24,7 +24,7 @@ from copy import deepcopy
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Any, TypedDict
+from typing import Any, Generic, TypedDict, TypeVar, cast

 import torch
 from huggingface_hub import ModelHubMixin, hf_hub_download
@@ -33,6 +33,9 @@ from safetensors.torch import load_file, save_file

 from lerobot.configs.types import PolicyFeature

+# Type variable for generic processor output type
+TOutput = TypeVar("TOutput")
+

 class TransitionKey(str, Enum):
     """Keys for accessing EnvTransition dictionary components."""
@@ -216,6 +219,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition:  # noq
     metadata without breaking the processor.
     """

+    # Validate input type
+    if not isinstance(batch, dict):
+        raise ValueError(f"EnvTransition must be a dictionary. Got {type(batch).__name__}")
+
     # Extract observation keys
     observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
     observation = observation_keys if observation_keys else None
@@ -279,8 +286,15 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
     return batch


+class ProcessorKwargs(TypedDict, total=False):
+    """Keyword arguments for RobotProcessor constructor."""
+
+    to_transition: Callable[[dict[str, Any]], EnvTransition] | None
+    to_output: Callable[[EnvTransition], Any] | None
+
+
 @dataclass
-class RobotProcessor(ModelHubMixin):
+class RobotProcessor(ModelHubMixin, Generic[TOutput]):
     """
     Composable, debuggable post-processing processor for robot transitions.

@@ -288,20 +302,43 @@
     left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts
     and batch dictionaries, automatically converting between formats as needed.

+    The processor is generic over its output type TOutput, which provides better type safety
+    and clarity about what the processor returns.
+
     Args:
         steps: Ordered list of processing steps executed on every call. Defaults to empty list.
         name: Human-readable identifier that is persisted inside the JSON config.
             Defaults to "RobotProcessor".
         to_transition: Function to convert batch dict to EnvTransition dict.
             Defaults to _default_batch_to_transition.
-        to_output: Function to convert EnvTransition dict to the desired output format.
-            Usually it is a batch dict or EnvTransition dict.
-            Defaults to _default_transition_to_batch.
+        to_output: Function to convert EnvTransition dict to the desired output format of type TOutput.
+            Defaults to _default_transition_to_batch (returns batch dict).
+            Use identity function (lambda x: x) for EnvTransition output.
         before_step_hooks: List of hooks called before each step. Each hook receives the step
             index and transition, and can optionally return a modified transition.
         after_step_hooks: List of hooks called after each step. Each hook receives the step
             index and transition, and can optionally return a modified transition.

+    Type Safety Examples:
+        ```python
+        # Default behavior - returns batch dict
+        processor: RobotProcessor[dict[str, Any]] = RobotProcessor(steps=[some_step1, some_step2])
+        result: dict[str, Any] = processor(batch_data)  # Type checker knows this is a dict
+
+        # For EnvTransition output, explicitly specify identity function
+        transition_processor: RobotProcessor[EnvTransition] = RobotProcessor(
+            steps=[some_step1, some_step2],
+            to_output=lambda x: x,  # Identity function
+        )
+        result: EnvTransition = transition_processor(batch_data)  # Type checker knows this is EnvTransition
+
+        # For custom output types
+        processor: RobotProcessor[str] = RobotProcessor(
+            steps=[custom_step], to_output=lambda t: f"Processed {len(t)} keys"
+        )
+        result: str = processor(batch_data)  # Type checker knows this is str
+        ```
+
     Hook Semantics:
         - Hooks are executed sequentially in the order they were registered. There is no way to
           reorder hooks after registration without creating a new pipeline.
@@ -323,8 +360,13 @@
     to_transition: Callable[[dict[str, Any]], EnvTransition] = field(
         default_factory=lambda: _default_batch_to_transition, repr=False
     )
-    to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field(
-        default_factory=lambda: _default_transition_to_batch, repr=False
+    to_output: Callable[[EnvTransition], TOutput] = field(
+        # Cast is necessary here: Working around Python type-checker limitation.
+        # _default_transition_to_batch returns dict[str, Any], but we need it to be TOutput
+        # for the generic to work. When no explicit type is given, TOutput defaults to dict[str, Any],
+        # making this cast safe.
+        default_factory=lambda: cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
+        repr=False,
     )

     # Processor-level hooks for observation/monitoring
@@ -332,98 +374,57 @@
     before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
     after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)

-    def __call__(self, data: EnvTransition | dict[str, Any]):
+    def __call__(self, data: dict[str, Any]) -> TOutput:
         """Process data through all steps.

-        The method accepts either the classic EnvTransition dict or a batch dictionary
-        (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied
-        it is first converted to the internal dict format using to_transition; after all
-        steps are executed the dict is transformed back into a batch dict with to_batch and the
-        result is returned – thereby preserving the caller's original data type.
+        The method accepts a batch dictionary (like the ones returned by ReplayBuffer or
+        LeRobotDataset). It is first converted to EnvTransition format using to_transition,
+        then processed through all steps, and finally converted to the output format using to_output.

         Args:
-            data: Either an EnvTransition dict or a batch dictionary to process.
+            data: A batch dictionary to process.

         Returns:
-            The processed data in the same format as the input (EnvTransition or batch dict).
-
-        Raises:
-            ValueError: If the transition is not a valid EnvTransition format.
+            The processed data in the format specified by to_output.
         """
-        # Check if we need to convert back to batch format at the end
-        _, called_with_batch = self._prepare_transition(data)
+        # Always convert input through to_transition
+        transition = self.to_transition(data)

-        # Use step_through to get the iterator
-        step_iterator = self.step_through(data)
-
-        # Get initial state (before any steps)
-        current_transition = next(step_iterator)
-
-        # Process each step with hooks
-        for idx, next_transition in enumerate(step_iterator):
-            # Apply before hooks with current state (before step execution)
+        # Process through all steps
+        for idx, processor_step in enumerate(self.steps):
+            # Apply before hooks
             for hook in self.before_step_hooks:
-                hook(idx, current_transition)
+                hook(idx, transition)

-            # Move to next state (after step execution)
-            current_transition = next_transition
+            # Execute step
+            transition = processor_step(transition)

-            # Apply after hooks with updated state
+            # Apply after hooks
             for hook in self.after_step_hooks:
-                hook(idx, current_transition)
+                hook(idx, transition)

-        # Convert back to original format if needed
-        if called_with_batch or self.to_output is not _default_transition_to_batch:
-            return self.to_output(current_transition)
-        else:
-            return current_transition
+        # Always use to_output for consistent typing
+        return self.to_output(transition)

-    def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
-        """Prepare and validate transition data for processing.
-
-        Args:
-            data: Either an EnvTransition dict or a batch dictionary to process.
-
-        Returns:
-            A tuple of (prepared_transition, called_with_batch_flag)
-
-        Raises:
-            ValueError: If the transition is not a valid EnvTransition format.
-        """
-        # Check if data is already an EnvTransition or needs conversion
-        if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
-            # It's a batch dict, convert it
-            called_with_batch = True
-            transition = self.to_transition(data)
-        else:
-            # It's already an EnvTransition
-            called_with_batch = False
-            transition = data
-
-        # Basic validation
-        if not isinstance(transition, dict):
-            raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
-
-        return transition, called_with_batch
-
-    def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
+    def step_through(self, data: dict[str, Any]) -> Iterable[EnvTransition]:
         """Yield the intermediate results after each processor step.

         This is a low-level method that does NOT apply hooks. It simply executes each step
         and yields the intermediate results. This allows users to debug the pipeline or
         apply custom logic between steps if needed.

-        Note: This method always yields EnvTransition objects regardless of input format.
-        If you need the results in the original input format, you'll need to convert them
+        Note: This method always yields EnvTransition objects regardless of output format.
+        If you need the results in the output format, you'll need to convert them
         using `to_output()`.

         Args:
-            data: Either an EnvTransition dict or a batch dictionary to process.
+            data: A batch dictionary to process.

         Yields:
             The intermediate EnvTransition results after each step.
         """
-        transition, _ = self._prepare_transition(data)
+        # Always convert input through to_transition
+        transition = self.to_transition(data)

         # Yield initial state
         yield transition
@@ -525,8 +526,10 @@
         revision: str | None = None,
         config_filename: str | None = None,
         overrides: dict[str, Any] | None = None,
+        to_transition: Callable[[dict[str, Any]], EnvTransition] | None = None,
+        to_output: Callable[[EnvTransition], TOutput] | None = None,
         **kwargs,
-    ) -> RobotProcessor:
+    ) -> RobotProcessor[TOutput]:
         """Load a serialized processor from source (local path or Hugging Face Hub identifier).

         Args:
@@ -540,9 +543,14 @@
                 (for registered steps). Values are dictionaries containing parameter overrides
                 that will be merged with the saved configuration. This is useful for providing
                 non-serializable objects like environment instances.
+            to_transition: Function to convert batch dict to EnvTransition dict.
+                Defaults to _default_batch_to_transition.
+            to_output: Function to convert EnvTransition dict to the desired output format of type T.
+                Defaults to _default_transition_to_batch (returns batch dict).
+                Use identity function (lambda x: x) for EnvTransition output.

         Returns:
-            A RobotProcessor instance loaded from the saved configuration.
+            A RobotProcessor[TOutput] instance loaded from the saved configuration.

         Raises:
             ImportError: If a processor step class cannot be loaded or imported.
@@ -756,19 +764,34 @@
                 f"Make sure override keys match exact step class names or registry names."
             )

-        return cls(steps, loaded_config.get("name", "RobotProcessor"))
+        return cls(
+            steps=steps,
+            name=loaded_config.get("name", "RobotProcessor"),
+            to_transition=to_transition or _default_batch_to_transition,
+            # Cast is necessary here: Same type-checker limitation as above.
+            # When to_output is None, we use the default which returns dict[str, Any].
+            # The cast ensures type consistency with the generic TOutput parameter.
+            to_output=to_output or cast(Callable[[EnvTransition], TOutput], _default_transition_to_batch),
+        )

     def __len__(self) -> int:
         """Return the number of steps in the processor."""
         return len(self.steps)

-    def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor:
+    def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor[TOutput]:
         """Indexing helper exposing underlying steps.
         * ``int`` – returns the idx-th ProcessorStep.
         * ``slice`` – returns a new RobotProcessor with the sliced steps.
         """
         if isinstance(idx, slice):
-            return RobotProcessor(self.steps[idx], self.name)
+            return RobotProcessor(
+                steps=self.steps[idx],
+                name=self.name,
+                to_transition=self.to_transition,
+                to_output=self.to_output,
+                before_step_hooks=self.before_step_hooks.copy(),
+                after_step_hooks=self.after_step_hooks.copy(),
+            )
         return self.steps[idx]

     def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
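To see what the `__call__` refactor above amounts to at runtime, here is a self-contained sketch (not part of the commit); `MiniProcessor` and the doubling step are illustrative stand-ins for `RobotProcessor` and a real `ProcessorStep`:

```python
from typing import Any, Callable, Generic, TypeVar

TOutput = TypeVar("TOutput")

# Simplified stand-in for lerobot's EnvTransition dict.
EnvTransition = dict[str, Any]


class MiniProcessor(Generic[TOutput]):
    """Distills the new RobotProcessor.__call__: convert in, step through, convert out."""

    def __init__(
        self,
        steps: list[Callable[[EnvTransition], EnvTransition]],
        to_transition: Callable[[dict[str, Any]], EnvTransition] = lambda b: b,
        to_output: Callable[[EnvTransition], Any] = lambda t: t,
    ) -> None:
        self.steps = steps
        self.to_transition = to_transition  # batch dict -> EnvTransition
        self.to_output = to_output  # EnvTransition -> TOutput

    def __call__(self, data: dict[str, Any]) -> TOutput:
        transition = self.to_transition(data)  # always convert on the way in
        for step in self.steps:
            transition = step(transition)  # run each step left-to-right
        return self.to_output(transition)  # always convert on the way out


doubler: MiniProcessor[EnvTransition] = MiniProcessor(
    steps=[lambda t: {k: v * 2 for k, v in t.items()}]
)
print(doubler({"observation.state": 1.0}))  # {'observation.state': 2.0}
```

The point of the change is visible in the last line of `__call__`: the return value is always whatever `to_output` produces, so the `TOutput` type parameter is honest and the old input-format branching (`_prepare_transition`) is unnecessary.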
@@ -102,7 +102,12 @@ def test_act_processor_normalization():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create test data
     observation = {OBS_STATE: torch.randn(7)}
@@ -131,7 +136,12 @@ def test_act_processor_cuda():
     config.device = "cuda"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create CPU data
     observation = {OBS_STATE: torch.randn(7)}
@@ -160,7 +170,12 @@ def test_act_processor_accelerate_scenario():
     config.device = "cuda:0"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Simulate Accelerate: data already on GPU
     device = torch.device("cuda:0")
@@ -223,14 +238,21 @@ def test_act_processor_save_and_load():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     with tempfile.TemporaryDirectory() as tmpdir:
         # Save preprocessor
         preprocessor.save_pretrained(tmpdir)

         # Load preprocessor
-        loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
+        loaded_preprocessor = RobotProcessor.from_pretrained(
+            tmpdir, to_transition=lambda x: x, to_output=lambda x: x
+        )

         # Test that loaded processor works
         observation = {OBS_STATE: torch.randn(7)}
@@ -249,7 +271,12 @@ def test_act_processor_device_placement_preservation():

     # Test with CPU config
     config.device = "cpu"
-    preprocessor, _ = make_act_pre_post_processors(config, stats)
+    preprocessor, _ = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Process CPU data
     observation = {OBS_STATE: torch.randn(7)}
@@ -269,12 +296,21 @@ def test_act_processor_mixed_precision():
     stats = create_default_stats()

     # Modify the device processor to use float16
-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Replace DeviceProcessor with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in preprocessor.steps:
         if isinstance(step, DeviceProcessor):
-            preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps

     # Create test data
     observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}
@@ -294,7 +330,12 @@ def test_act_processor_batch_consistency():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_act_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Test single sample (unbatched)
     observation = {OBS_STATE: torch.randn(7)}
@@ -245,7 +245,7 @@ def test_mixed_observation():
 def test_integration_with_robot_processor():
     """Test ToBatchProcessor integration with RobotProcessor."""
     to_batch_processor = ToBatchProcessor()
-    pipeline = RobotProcessor([to_batch_processor])
+    pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)

     # Create unbatched observation
     observation = {
@@ -285,7 +285,9 @@ def test_serialization_methods():
 def test_save_and_load_pretrained():
     """Test saving and loading ToBatchProcessor with RobotProcessor."""
     processor = ToBatchProcessor()
-    pipeline = RobotProcessor([processor], name="BatchPipeline")
+    pipeline = RobotProcessor(
+        [processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
+    )

     with tempfile.TemporaryDirectory() as tmp_dir:
         # Save pipeline
@@ -296,7 +298,9 @@ def test_save_and_load_pretrained():
         assert config_path.exists()

         # Load pipeline
-        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
+        loaded_pipeline = RobotProcessor.from_pretrained(
+            tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
+        )

         assert loaded_pipeline.name == "BatchPipeline"
         assert len(loaded_pipeline) == 1
@@ -323,11 +327,13 @@ def test_registry_functionality():
 def test_registry_based_save_load():
     """Test saving and loading using registry name."""
     processor = ToBatchProcessor()
-    pipeline = RobotProcessor([processor])
+    pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)

     with tempfile.TemporaryDirectory() as tmp_dir:
         pipeline.save_pretrained(tmp_dir)
-        loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
+        loaded_pipeline = RobotProcessor.from_pretrained(
+            tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
+        )

         # Verify the loaded processor works
         observation = {
@@ -97,7 +97,12 @@ def test_classifier_processor_normalization():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    preprocessor, postprocessor = make_classifier_processor(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create test data
     observation = {
@@ -123,7 +128,12 @@ def test_classifier_processor_cuda():
     config.device = "cuda"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    preprocessor, postprocessor = make_classifier_processor(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create CPU data
     observation = {
@@ -156,7 +166,12 @@ def test_classifier_processor_accelerate_scenario():
     config.device = "cuda:0"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    preprocessor, postprocessor = make_classifier_processor(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Simulate Accelerate: data already on GPU
     device = torch.device("cuda:0")
@@ -230,14 +245,22 @@ def test_classifier_processor_save_and_load():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    # Get the steps from the factory function
+    factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)
+
+    # Create new processors with EnvTransition input/output
+    preprocessor = RobotProcessor(
+        factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
+    )

     with tempfile.TemporaryDirectory() as tmpdir:
         # Save preprocessor
         preprocessor.save_pretrained(tmpdir)

         # Load preprocessor
-        loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
+        loaded_preprocessor = RobotProcessor.from_pretrained(
+            tmpdir, to_transition=lambda x: x, to_output=lambda x: x
+        )

         # Test that loaded processor works
         observation = {
@@ -260,13 +283,19 @@ def test_classifier_processor_mixed_precision():
     config.device = "cuda"
     stats = create_default_stats()

-    # Create processor
-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    # Get the steps from the factory function
+    factory_preprocessor, factory_postprocessor = make_classifier_processor(config, stats)

     # Replace DeviceProcessor with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in factory_preprocessor.steps:
         if isinstance(step, DeviceProcessor):
-            preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
+        else:
+            modified_steps.append(step)
+
+    # Create new processors with EnvTransition input/output
+    preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)

     # Create test data
     observation = {
@@ -290,7 +319,12 @@ def test_classifier_processor_batch_data():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    preprocessor, postprocessor = make_classifier_processor(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Test with batched data
     batch_size = 16
@@ -315,7 +349,12 @@ def test_classifier_processor_postprocessor_identity():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_classifier_processor(config, stats)
+    preprocessor, postprocessor = make_classifier_processor(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create test data for postprocessor
     reward = torch.tensor([[0.8], [0.3], [0.9]])  # Batch of rewards/predictions
@@ -311,7 +311,12 @@ def test_integration_with_robot_processor():
     device_processor = DeviceProcessor(device="cpu")
     batch_processor = ToBatchProcessor()

-    processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
+    processor = RobotProcessor(
+        steps=[batch_processor, device_processor],
+        name="test_pipeline",
+        to_transition=lambda x: x,
+        to_output=lambda x: x,
+    )

     # Create test data
     observation = {OBS_STATE: torch.randn(10)}
@@ -985,6 +990,8 @@ def test_policy_processor_integration():
             DeviceProcessor(device="cuda"),
         ],
         name="test_preprocessor",
+        to_transition=lambda x: x,
+        to_output=lambda x: x,
     )

     # Create output processor (postprocessor) that moves to CPU
@@ -994,6 +1001,8 @@ def test_policy_processor_integration():
             UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
         ],
         name="test_postprocessor",
+        to_transition=lambda x: x,
+        to_output=lambda x: x,
     )

     # Test data on CPU
@@ -105,7 +105,12 @@ def test_diffusion_processor_with_images():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_diffusion_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create test data with images
     observation = {
@@ -131,7 +136,12 @@ def test_diffusion_processor_cuda():
     config.device = "cuda"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_diffusion_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create CPU data
     observation = {
@@ -164,7 +174,12 @@ def test_diffusion_processor_accelerate_scenario():
     config.device = "cuda:0"
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_diffusion_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Simulate Accelerate: data already on GPU
     device = torch.device("cuda:0")
@@ -238,14 +253,22 @@ def test_diffusion_processor_save_and_load():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    # Get the steps from the factory function
+    factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)
+
+    # Create new processors with EnvTransition input/output
+    preprocessor = RobotProcessor(
+        factory_preprocessor.steps, to_transition=lambda x: x, to_output=lambda x: x
+    )

     with tempfile.TemporaryDirectory() as tmpdir:
         # Save preprocessor
         preprocessor.save_pretrained(tmpdir)

         # Load preprocessor
-        loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
+        loaded_preprocessor = RobotProcessor.from_pretrained(
+            tmpdir, to_transition=lambda x: x, to_output=lambda x: x
+        )

         # Test that loaded processor works
         observation = {
@@ -268,13 +291,19 @@ def test_diffusion_processor_mixed_precision():
     config.device = "cuda"
     stats = create_default_stats()

-    # Create processor
-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    # Get the steps from the factory function
+    factory_preprocessor, factory_postprocessor = make_diffusion_pre_post_processors(config, stats)

     # Replace DeviceProcessor with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in factory_preprocessor.steps:
         if isinstance(step, DeviceProcessor):
-            preprocessor.steps[i] = DeviceProcessor(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessor(device=config.device, float_dtype="float16"))
+        else:
+            modified_steps.append(step)
+
+    # Create new processors with EnvTransition input/output
+    preprocessor = RobotProcessor(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)

     # Create test data
     observation = {
@@ -298,7 +327,12 @@ def test_diffusion_processor_identity_normalization():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_diffusion_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Create test data
     image_value = torch.rand(3, 224, 224) * 255  # Large values
@@ -322,7 +356,12 @@ def test_diffusion_processor_batch_consistency():
     config = create_default_config()
     stats = create_default_stats()

-    preprocessor, postprocessor = make_diffusion_pre_post_processors(config, stats)
+    preprocessor, postprocessor = make_diffusion_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )

     # Test with different batch sizes
     for batch_size in [1, 8, 32]:
@@ -506,7 +506,7 @@ def test_get_config(full_stats):

 def test_integration_with_robot_processor(normalizer_processor):
     """Test integration with RobotProcessor pipeline"""
-    robot_processor = RobotProcessor([normalizer_processor])
+    robot_processor = RobotProcessor([normalizer_processor], to_transition=lambda x: x, to_output=lambda x: x)

     observation = {
         "observation.image": torch.tensor([0.7, 0.5, 0.3]),
@@ -1317,7 +1317,7 @@ def test_hotswap_stats_functional_test():

     # Create original processor
     normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=initial_stats)
-    original_processor = RobotProcessor(steps=[normalizer])
+    original_processor = RobotProcessor(steps=[normalizer], to_transition=lambda x: x, to_output=lambda x: x)

     # Process with original stats
     original_result = original_processor(transition)
@@ -84,7 +84,12 @@ def test_make_pi0_processor_basic():
     stats = create_default_stats()

     with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
+        preprocessor, postprocessor = make_pi0_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )

     # Check processor names
     assert preprocessor.name == "robot_preprocessor"
@@ -183,7 +188,12 @@ def test_pi0_processor_cuda():
             return features

     with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
+        preprocessor, postprocessor = make_pi0_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )

     # Create CPU data
     observation = {
@@ -233,7 +243,12 @@ def test_pi0_processor_accelerate_scenario():
             return features

     with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
+        preprocessor, postprocessor = make_pi0_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )

     # Simulate Accelerate: data already on GPU and batched
     device = torch.device("cuda:0")
@@ -284,7 +299,12 @@ def test_pi0_processor_multi_gpu():
             return features

     with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor", MockTokenizerProcessor):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(config, stats)
+        preprocessor, postprocessor = make_pi0_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )

     # Simulate data on different GPU
     device = torch.device("cuda:1")
@@ -310,7 +330,12 @@ def test_pi0_processor_without_stats():

     # Mock the tokenizer processor
     with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessor"):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(config, dataset_stats=None)
+        preprocessor, postprocessor = make_pi0_pre_post_processors(
+            config,
+            dataset_stats=None,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )

     # Should still create processors
     assert preprocessor is not None
@@ -176,7 +176,7 @@ class MockStepWithTensorState:

def test_empty_pipeline():
"""Test pipeline with no steps."""
pipeline = RobotProcessor()
pipeline = RobotProcessor([], to_transition=lambda x: x, to_output=lambda x: x)

transition = create_transition()
result = pipeline(transition)
@@ -188,7 +188,7 @@ def test_empty_pipeline():
def test_single_step_pipeline():
"""Test pipeline with a single step."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
pipeline = RobotProcessor([step], to_transition=lambda x: x, to_output=lambda x: x)

transition = create_transition()
result = pipeline(transition)
@@ -205,7 +205,7 @@ def test_multiple_steps_pipeline():
"""Test pipeline with multiple steps."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
pipeline = RobotProcessor([step1, step2], to_transition=lambda x: x, to_output=lambda x: x)

transition = create_transition()
result = pipeline(transition)
@@ -557,7 +557,9 @@ def test_save_and_load_pretrained():
def test_step_without_optional_methods():
"""Test pipeline with steps that don't implement optional methods."""
step = MockStepWithoutOptionalMethods(multiplier=3.0)
pipeline = RobotProcessor([step])
pipeline = RobotProcessor(
[step], to_transition=lambda x: x, to_output=lambda x: x
) # Identity for EnvTransition input/output

transition = create_transition(reward=2.0)
result = pipeline(transition)
@@ -878,7 +880,9 @@ def test_from_pretrained_with_overrides():
"registered_mock_step": {"device": "cuda", "value": 200},
}

loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)

# Verify the pipeline was loaded correctly
assert len(loaded_pipeline) == 2
@@ -914,7 +918,9 @@ def test_from_pretrained_with_partial_overrides():

# The current implementation applies overrides to ALL steps with the same class name
# Both steps will get the override
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)

transition = create_transition(reward=1.0)
result = loaded_pipeline(transition)
@@ -971,7 +977,9 @@ def test_from_pretrained_registered_step_override():
# Override using registry name
overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}}

loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)

# Test that overrides were applied
transition = create_transition()
@@ -999,7 +1007,9 @@ def test_from_pretrained_mixed_registered_and_unregistered():
"registered_mock_step": {"value": 777},
}

loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides=overrides, to_transition=lambda x: x, to_output=lambda x: x
)

# Test both steps
transition = create_transition(reward=2.0)
@@ -1020,7 +1030,9 @@ def test_from_pretrained_no_overrides():
pipeline.save_pretrained(tmp_dir)

# Load without overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)

assert len(loaded_pipeline) == 1

@@ -1040,7 +1052,9 @@ def test_from_pretrained_empty_overrides():
pipeline.save_pretrained(tmp_dir)

# Load with empty overrides
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={})
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, overrides={}, to_transition=lambda x: x, to_output=lambda x: x
)

assert len(loaded_pipeline) == 1

@@ -1455,6 +1469,8 @@ def test_override_with_nested_config():
loaded = RobotProcessor.from_pretrained(
tmp_dir,
overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}},
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Test that override worked
@@ -1553,7 +1569,10 @@ def test_override_with_callables():

# Load with callable override
loaded = RobotProcessor.from_pretrained(
tmp_dir, overrides={"callable_step": {"transform_fn": double_values}}
tmp_dir,
overrides={"callable_step": {"transform_fn": double_values}},
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Test it works
@@ -1857,7 +1876,8 @@ def test_save_load_with_custom_converter_functions():

# Should work with standard format (wouldn't work with custom converter)
result = loaded(batch)
assert "observation.image" in result # Standard format preserved
# With new behavior, default to_output is _default_transition_to_batch, so result is batch dict
assert "observation.image" in result


class NonCompliantStep:

@@ -188,7 +188,7 @@ def test_integration_with_robot_processor():
}
rename_processor = RenameProcessor(rename_map=rename_map)

pipeline = RobotProcessor([rename_processor])
pipeline = RobotProcessor([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)

observation = {
"agent_pos": np.array([1.0, 2.0, 3.0]),
@@ -236,7 +236,9 @@ def test_save_and_load_pretrained():
assert len(state_files) == 0

# Load pipeline
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
loaded_pipeline = RobotProcessor.from_pretrained(
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
)

assert loaded_pipeline.name == "TestRenameProcessor"
assert len(loaded_pipeline) == 1
@@ -277,7 +279,7 @@ def test_registry_functionality():
def test_registry_based_save_load():
"""Test save/load using registry name instead of module path."""
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
pipeline = RobotProcessor([processor])
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)

with tempfile.TemporaryDirectory() as tmp_dir:
# Save and load
@@ -318,7 +320,7 @@ def test_chained_rename_processors():
}
)

pipeline = RobotProcessor([processor1, processor2])
pipeline = RobotProcessor([processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x)

observation = {
"pos": np.array([1.0, 2.0]),

@@ -78,7 +78,12 @@ def test_make_sac_processor_basic():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -102,7 +107,12 @@ def test_sac_processor_normalization_modes():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create test data
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
@@ -133,7 +143,12 @@ def test_sac_processor_cuda():
config.device = "cuda"
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create CPU data
observation = {OBS_STATE: torch.randn(10)}
@@ -162,7 +177,12 @@ def test_sac_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -185,7 +205,12 @@ def test_sac_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -205,7 +230,22 @@ def test_sac_processor_without_stats():
"""Test SAC processor creation without dataset statistics."""
config = create_default_config()

preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)

# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Should still create processors
assert preprocessor is not None
@@ -225,14 +265,21 @@ def test_sac_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)

# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)

# Test that loaded processor works
observation = {OBS_STATE: torch.randn(10)}
@@ -252,7 +299,12 @@ def test_sac_processor_mixed_precision():
stats = create_default_stats()

# Create processor
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -277,7 +329,12 @@ def test_sac_processor_batch_data():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Test with batched data
batch_size = 32
@@ -298,7 +355,12 @@ def test_sac_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
preprocessor, postprocessor = make_sac_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Test with empty observation
transition = create_transition(observation={}, action=torch.randn(5))

@@ -89,7 +89,12 @@ def test_make_smolvla_processor_basic():
stats = create_default_stats()

with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -188,7 +193,12 @@ def test_smolvla_processor_cuda():
return features

with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create CPU data
observation = {
@@ -238,7 +248,12 @@ def test_smolvla_processor_accelerate_scenario():
return features

with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -289,7 +304,12 @@ def test_smolvla_processor_multi_gpu():
return features

with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor", MockTokenizerProcessor):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, stats)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -315,7 +335,12 @@ def test_smolvla_processor_without_stats():

# Mock the tokenizer processor
with patch("lerobot.policies.smolvla.processor_smolvla.TokenizerProcessor"):
preprocessor, postprocessor = make_smolvla_pre_post_processors(config, dataset_stats=None)
preprocessor, postprocessor = make_smolvla_pre_post_processors(
config,
dataset_stats=None,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Should still create processors
assert preprocessor is not None

@@ -81,7 +81,12 @@ def test_make_tdmpc_processor_basic():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,7 +110,12 @@ def test_tdmpc_processor_normalization():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create test data
observation = {
@@ -138,7 +148,12 @@ def test_tdmpc_processor_cuda():
config.device = "cuda"
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create CPU data
observation = {
@@ -171,7 +186,12 @@ def test_tdmpc_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -198,7 +218,12 @@ def test_tdmpc_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -222,7 +247,22 @@ def test_tdmpc_processor_without_stats():
"""Test TDMPC processor creation without dataset statistics."""
config = create_default_config()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)

# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Should still create processors
assert preprocessor is not None
@@ -245,14 +285,21 @@ def test_tdmpc_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)

# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)

# Test that loaded processor works
observation = {
@@ -276,7 +323,12 @@ def test_tdmpc_processor_mixed_precision():
stats = create_default_stats()

# Create processor
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -305,7 +357,12 @@ def test_tdmpc_processor_batch_data():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Test with batched data
batch_size = 64
@@ -330,7 +387,12 @@ def test_tdmpc_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Test with only state observation (no image)
observation = {OBS_STATE: torch.randn(12)}

@@ -389,7 +389,7 @@ def test_integration_with_robot_processor(mock_auto_tokenizer):
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
robot_processor = RobotProcessor([tokenizer_processor])
robot_processor = RobotProcessor([tokenizer_processor], to_transition=lambda x: x, to_output=lambda x: x)

transition = create_transition(
observation={"state": torch.tensor([1.0, 2.0])},
@@ -427,14 +427,16 @@ def test_save_and_load_pretrained_with_tokenizer_name(mock_auto_tokenizer):
tokenizer_name="test-tokenizer", max_length=32, task_key="instruction"
)

robot_processor = RobotProcessor([original_processor])
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)

with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
robot_processor.save_pretrained(temp_dir)

# Load processor - tokenizer will be recreated from saved config
loaded_processor = RobotProcessor.from_pretrained(temp_dir)
loaded_processor = RobotProcessor.from_pretrained(
temp_dir, to_transition=lambda x: x, to_output=lambda x: x
)

# Test that loaded processor works
transition = create_transition(
@@ -456,7 +458,7 @@ def test_save_and_load_pretrained_with_tokenizer_object():

original_processor = TokenizerProcessor(tokenizer=mock_tokenizer, max_length=32, task_key="instruction")

robot_processor = RobotProcessor([original_processor])
robot_processor = RobotProcessor([original_processor], to_transition=lambda x: x, to_output=lambda x: x)

with tempfile.TemporaryDirectory() as temp_dir:
# Save processor
@@ -464,7 +466,10 @@ def test_save_and_load_pretrained_with_tokenizer_object():

# Load processor with tokenizer override (since tokenizer object wasn't saved)
loaded_processor = RobotProcessor.from_pretrained(
temp_dir, overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}}
temp_dir,
overrides={"tokenizer_processor": {"tokenizer": mock_tokenizer}},
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Test that loaded processor works
@@ -952,7 +957,9 @@ def test_integration_with_device_processor(mock_auto_tokenizer):
# Create pipeline with TokenizerProcessor then DeviceProcessor
tokenizer_processor = TokenizerProcessor(tokenizer_name="test-tokenizer", max_length=6)
device_processor = DeviceProcessor(device="cuda:0")
robot_processor = RobotProcessor([tokenizer_processor, device_processor])
robot_processor = RobotProcessor(
[tokenizer_processor, device_processor], to_transition=lambda x: x, to_output=lambda x: x
)

# Start with CPU tensors
transition = create_transition(

@@ -81,7 +81,12 @@ def test_make_vqbet_processor_basic():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,7 +110,12 @@ def test_vqbet_processor_with_images():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create test data with images and states
observation = {
@@ -131,7 +141,12 @@ def test_vqbet_processor_cuda():
config.device = "cuda"
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Create CPU data
observation = {
@@ -164,7 +179,12 @@ def test_vqbet_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
@@ -191,7 +211,12 @@ def test_vqbet_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -215,7 +240,22 @@ def test_vqbet_processor_without_stats():
"""Test VQBeT processor creation without dataset statistics."""
config = create_default_config()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_vqbet_pre_post_processors(config, dataset_stats=None)

# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)

# Should still create processors
assert preprocessor is not None
@@ -238,14 +278,21 @@ def test_vqbet_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)

# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)

# Test that loaded processor works
observation = {
@@ -269,7 +316,12 @@ def test_vqbet_processor_mixed_precision():
stats = create_default_stats()

# Create processor
preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -298,7 +350,12 @@ def test_vqbet_processor_large_batch():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Test with large batch
batch_size = 128
@@ -323,7 +380,12 @@ def test_vqbet_processor_sequential_processing():
config = create_default_config()
stats = create_default_stats()

preprocessor, postprocessor = make_vqbet_pre_post_processors(config, stats)
preprocessor, postprocessor = make_vqbet_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)

# Process multiple samples sequentially
results = []
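Taken together, the updated tests all exercise one calling convention: each per-policy factory now accepts optional `preprocessor_kwargs` / `postprocessor_kwargs` dicts that are forwarded to the `RobotProcessor` constructor, and passing identity lambdas for `to_transition` and `to_output` makes a pipeline consume and return raw `EnvTransition` objects instead of the default batch dicts. A minimal sketch of that usage, assuming the test-file helpers `create_default_config()`, `create_default_stats()`, and `create_transition()` shown in the hunks above are in scope:

# Sketch only: helper names are borrowed from the tests above, not a public API.
preprocessor, postprocessor = make_act_pre_post_processors(
    create_default_config(),
    create_default_stats(),
    # Identity converters: skip the default batch-dict conversion on both ends.
    preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
    postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
transition = create_transition()
processed = preprocessor(transition)  # EnvTransition in, EnvTransition out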