mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(normalization): restricting 255 normalization to non depth/uint8 images only
This commit is contained in:
@@ -105,8 +105,9 @@ def raw_observation_to_observation(
|
||||
|
||||
|
||||
def prepare_image(image: torch.Tensor) -> torch.Tensor:
|
||||
"""Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
image = image.type(torch.float32) / 255
|
||||
"""Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor"""
|
||||
if image.dtype == torch.uint8:
|
||||
image = image.type(torch.float32) / 255
|
||||
image = image.contiguous()
|
||||
|
||||
return image
|
||||
|
||||
@@ -531,8 +531,9 @@ def compute_episode_stats(
|
||||
)
|
||||
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
normalization_factor = 255.0 if key not in features.depth_keys else 1.0
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
|
||||
return ep_stats
|
||||
|
||||
@@ -299,10 +299,11 @@ class DatasetWriter:
|
||||
if use_streaming:
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self._meta.video_keys:
|
||||
normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
|
||||
@@ -126,7 +126,8 @@ def prepare_observation_for_inference(
|
||||
for name in observation:
|
||||
observation[name] = torch.from_numpy(observation[name])
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
if observation[name].dtype == torch.uint8:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
Reference in New Issue
Block a user