fix(pytorch audio format): switch to PyTorch's default channel-first format for audio

This commit is contained in:
CarolinePascal
2025-04-28 19:40:22 +02:00
parent b573b7a052
commit 8e29c530ed
+18 -13
View File
@@ -82,12 +82,7 @@ def decode_audio_torchcodec(
start_seconds=max(0.0, ts - duration), stop_seconds=ts
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)
current_audio_chunk_data = current_audio_chunk.data.t()
current_audio_chunk_data = current_audio_chunk.data
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
@@ -95,16 +90,22 @@ def decode_audio_torchcodec(
if ts < 1 / audio_sample_rate:
# TODO(CarolinePascal): add low-level white noise instead of zeros?
current_audio_chunk_data = torch.zeros(
(int(ceil(duration * audio_sample_rate)), audio_channels)
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal): add low-level white noise instead of zeros?
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0)
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
logging.info(
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)
@@ -144,20 +145,24 @@ def decode_audio_torchaudio(
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
current_audio_chunk = reader.pop_chunks()[0]
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
# Case where the requested audio chunk starts before the beginning of the audio stream
if ts - duration < 0:
# No useful audio sample has been recorded
if ts < 1 / audio_sample_rate:
current_audio_chunk = torch.zeros((int(ceil(duration * audio_sample_rate)), audio_channels))
current_audio_chunk_data = torch.zeros(
(audio_channels, int(ceil(duration * audio_sample_rate)))
)
# At least one useful audio sample has been recorded
else:
# Remove the superfluous last samples of the audio chunk
current_audio_chunk = current_audio_chunk[: int(ceil(ts * audio_sample_rate))]
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
# Pad the beginning of the audio chunk with zeros
# TODO(CarolinePascal): add low-level white noise instead of zeros?
current_audio_chunk = torch.nn.functional.pad(
current_audio_chunk, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0)
current_audio_chunk_data = torch.nn.functional.pad(
current_audio_chunk_data,
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
)
if log_loaded_timestamps:
@@ -165,7 +170,7 @@ def decode_audio_torchaudio(
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
)
audio_chunks.append(current_audio_chunk)
audio_chunks.append(current_audio_chunk_data)
audio_chunks = torch.stack(audio_chunks)