From 8e29c530ed6ef1fde240e9ea30f2917914f83881 Mon Sep 17 00:00:00 2001 From: CarolinePascal Date: Mon, 28 Apr 2025 19:40:22 +0200 Subject: [PATCH] fix(pytorch audio format): switching to pytorch's default channel first format for audio --- src/lerobot/datasets/audio_utils.py | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/lerobot/datasets/audio_utils.py b/src/lerobot/datasets/audio_utils.py index 342c8e48a..ef4f39c13 100644 --- a/src/lerobot/datasets/audio_utils.py +++ b/src/lerobot/datasets/audio_utils.py @@ -82,12 +82,7 @@ def decode_audio_torchcodec( start_seconds=max(0.0, ts - duration), stop_seconds=ts ) - if log_loaded_timestamps: - logging.info( - f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}" - ) - - current_audio_chunk_data = current_audio_chunk.data.t() + current_audio_chunk_data = current_audio_chunk.data # Case where the requested audio chunk starts before the beginning of the audio stream if ts - duration < 0: @@ -95,16 +90,22 @@ def decode_audio_torchcodec( if ts < 1 / audio_sample_rate: # TODO(CarolinePascal) : add low level white noise instead of zeros ? current_audio_chunk_data = torch.zeros( - (int(ceil(duration * audio_sample_rate)), audio_channels) + (audio_channels, int(ceil(duration * audio_sample_rate))) ) # At least one useful audio sample has been recorded else: # Pad the beginning of the audio chunk with zeros # TODO(CarolinePascal) : add low level white noise instead of zeros ? current_audio_chunk_data = torch.nn.functional.pad( - current_audio_chunk_data, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0) + current_audio_chunk_data, + (int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom ) + if log_loaded_timestamps: + logging.info( + f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}" + ) + audio_chunks.append(current_audio_chunk_data) audio_chunks = torch.stack(audio_chunks) @@ -144,20 +145,24 @@ def decode_audio_torchaudio( logging.warning("Audio stream reached end of recording before decoding desired timestamps.") current_audio_chunk = reader.pop_chunks()[0] + current_audio_chunk_data = current_audio_chunk.t() # Channel first format # Case where the requested audio chunk starts before the beginning of the audio stream if ts - duration < 0: # No useful audio sample has been recorded if ts < 1 / audio_sample_rate: - current_audio_chunk = torch.zeros((int(ceil(duration * audio_sample_rate)), audio_channels)) + current_audio_chunk_data = torch.zeros( + (audio_channels, int(ceil(duration * audio_sample_rate))) + ) # At least one useful audio sample has been recorded else: # Remove the superfluous last samples of the audio chunk - current_audio_chunk = current_audio_chunk[: int(ceil(ts * audio_sample_rate))] + current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))] # Pad the beginning of the audio chunk with zeros # TODO(CarolinePascal) : add low level white noise instead of zeros ? - current_audio_chunk = torch.nn.functional.pad( - current_audio_chunk, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0) + current_audio_chunk_data = torch.nn.functional.pad( + current_audio_chunk_data, + (int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom ) if log_loaded_timestamps: @@ -165,7 +170,7 @@ def decode_audio_torchaudio( f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}" ) - audio_chunks.append(current_audio_chunk) + audio_chunks.append(current_audio_chunk_data) audio_chunks = torch.stack(audio_chunks)