switch normalization order pipeline for pi05

This commit is contained in:
Pepijn
2025-09-25 23:38:06 +02:00
parent a196639c09
commit 18db813e2d
+9 -7
View File
@@ -71,7 +71,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
@@ -84,7 +84,7 @@ class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
# Normalize state to [-1, 1] range if needed (assuming it's already normalized from normalize_inputs)
# Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!)
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
return transition
@@ -133,6 +133,13 @@ def make_pi05_pre_post_processors(
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
@@ -141,11 +148,6 @@ def make_pi05_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = [