diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 4931c68c5..54f0ca69f 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -105,8 +105,9 @@ def raw_observation_to_observation( def prepare_image(image: torch.Tensor) -> torch.Tensor: - """Minimal preprocessing to turn int8 images to float32 in [0, 1], and create a memory-contiguous tensor""" - image = image.type(torch.float32) / 255 + """Minimal preprocessing to turn RGB uint8 images to float32 in [0, 1], and create a memory-contiguous tensor""" + if image.dtype == torch.uint8: + image = image.type(torch.float32) / 255 image = image.contiguous() return image diff --git a/src/lerobot/datasets/compute_stats.py b/src/lerobot/datasets/compute_stats.py index b2438f1d7..fc9a4c562 100644 --- a/src/lerobot/datasets/compute_stats.py +++ b/src/lerobot/datasets/compute_stats.py @@ -531,8 +531,9 @@ def compute_episode_stats( ) if features[key]["dtype"] in ["image", "video"]: + normalization_factor = 255.0 if key not in features.depth_keys else 1.0 ep_stats[key] = { - k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() + k: v if k == "count" else np.squeeze(v / normalization_factor, axis=0) for k, v in ep_stats[key].items() } return ep_stats diff --git a/src/lerobot/datasets/dataset_writer.py b/src/lerobot/datasets/dataset_writer.py index c8655555e..deb60b259 100644 --- a/src/lerobot/datasets/dataset_writer.py +++ b/src/lerobot/datasets/dataset_writer.py @@ -299,10 +299,11 @@ class DatasetWriter: if use_streaming: streaming_results = self._streaming_encoder.finish_episode() for video_key in self._meta.video_keys: + normalization_factor = 255.0 if video_key not in self._meta.depth_keys else 1.0 temp_path, video_stats = streaming_results[video_key] if video_stats is not None: ep_stats[video_key] = { - k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / normalization_factor, axis=0) for k, v in video_stats.items() } ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index c37127813..f465fcff8 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -126,7 +126,8 @@ def prepare_observation_for_inference( for name in observation: observation[name] = torch.from_numpy(observation[name]) if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 + if observation[name].dtype == torch.uint8: + observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() observation[name] = observation[name].unsqueeze(0) observation[name] = observation[name].to(device)