Adding pytorch compatible conversion for audio

This commit is contained in:
CarolinePascal
2025-04-04 18:31:00 +02:00
parent 5276fc0d6f
commit cdd3a859ef
3 changed files with 8 additions and 5 deletions
+3 -3
View File
@@ -78,8 +78,9 @@ def decode_audio_torchaudio(
# TODO(CarolinePascal) : sort timestamps ?
reader.add_basic_audio_stream(
frames_per_chunk=int(ceil(duration * audio_sampling_rate)), # Too much is better than not enough
buffer_chunk_size=-1, # No dropping frames
frames_per_chunk = int(ceil(duration * audio_sampling_rate)), #Too much is better than not enough
buffer_chunk_size = -1, #No dropping frames
format = "fltp", #Format as float32
)
audio_chunks = []
@@ -99,7 +100,6 @@ def decode_audio_torchaudio(
audio_chunks.append(current_audio_chunk)
audio_chunks = torch.stack(audio_chunks)
# TODO(CarolinePascal) : pytorch format conversion ?
assert len(timestamps) == len(audio_chunks)
return audio_chunks
+4 -1
View File
@@ -106,7 +106,7 @@ def prepare_observation_for_inference(
This function takes a dictionary of NumPy arrays, performs necessary
preprocessing, and prepares it for model inference. The steps include:
1. Converting NumPy arrays to PyTorch tensors.
2. Normalizing and permuting image data (if any).
2. Normalizing and permuting image data and audio data (if any).
3. Adding a batch dimension to each tensor.
4. Moving all tensors to the specified compute device.
5. Adding task and robot type information to the dictionary.
@@ -129,6 +129,9 @@ def prepare_observation_for_inference(
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
elif "audio" in name:
observation[name] = observation[name].type(torch.float32)
observation[name] = observation[name].permute(1, 0).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
+1 -1
View File
@@ -102,7 +102,7 @@ def predict_action(
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
# Convert to pytorch format: normalizing and permuting (channel first)
observation = prepare_observation_for_inference(observation, device, task, robot_type)
observation = preprocessor(observation)