fix(normalization): restricting 255 normalization to non depth/uint8 images only

This commit is contained in:
CarolinePascal
2026-05-26 14:17:03 +02:00
parent ba7f23adf9
commit e961f8fec0
4 changed files with 9 additions and 5 deletions
+3 -2
View File
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
def prepare_image(image: torch.Tensor) -> torch.Tensor:
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
image = image.type(torch.float32) / 255
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
if image.dtype == torch.uint8:
image = image.type(torch.float32) / 255
image = image.contiguous()
return image
+2 -1
View File
@@ -531,8 +531,9 @@ def compute_episode_stats(
)
if features[key]["dtype"] in ["image", "video"]:
normalization_factor = 255.0 if key not in features.depth_keys else 1.0
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
+2 -1
View File
@@ -299,10 +299,11 @@ class DatasetWriter:
if use_streaming:
streaming_results = self._streaming_encoder.finish_episode()
for video_key in self._meta.video_keys:
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
temp_path, video_stats = streaming_results[video_key]
if video_stats is not None:
ep_stats[video_key] = {
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
for k, v in video_stats.items()
}
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
+2 -1
View File
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
for name in observation:
observation[name] = torch.from_numpy(observation[name])
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
if observation[name].dtype == torch.uint8:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)