mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
fix(pytorch audio format): switching to PyTorch's default channel-first format for audio
This commit is contained in:
@@ -82,12 +82,7 @@ def decode_audio_torchcodec(
|
|||||||
start_seconds=max(0.0, ts - duration), stop_seconds=ts
|
start_seconds=max(0.0, ts - duration), stop_seconds=ts
|
||||||
)
|
)
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
current_audio_chunk_data = current_audio_chunk.data
|
||||||
logging.info(
|
|
||||||
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
current_audio_chunk_data = current_audio_chunk.data.t()
|
|
||||||
|
|
||||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||||
if ts - duration < 0:
|
if ts - duration < 0:
|
||||||
@@ -95,16 +90,22 @@ def decode_audio_torchcodec(
|
|||||||
if ts < 1 / audio_sample_rate:
|
if ts < 1 / audio_sample_rate:
|
||||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
current_audio_chunk_data = torch.zeros(
|
current_audio_chunk_data = torch.zeros(
|
||||||
(int(ceil(duration * audio_sample_rate)), audio_channels)
|
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||||
)
|
)
|
||||||
# At least one useful audio sample has been recorded
|
# At least one useful audio sample has been recorded
|
||||||
else:
|
else:
|
||||||
# Pad the beginning of the audio chunk with zeros
|
# Pad the beginning of the audio chunk with zeros
|
||||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
current_audio_chunk_data = torch.nn.functional.pad(
|
current_audio_chunk_data = torch.nn.functional.pad(
|
||||||
current_audio_chunk_data, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0)
|
current_audio_chunk_data,
|
||||||
|
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if log_loaded_timestamps:
|
||||||
|
logging.info(
|
||||||
|
f"audio chunk loaded at timestamp={current_audio_chunk.pts_seconds:.4f} with duration={current_audio_chunk.duration_seconds:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
audio_chunks.append(current_audio_chunk_data)
|
audio_chunks.append(current_audio_chunk_data)
|
||||||
|
|
||||||
audio_chunks = torch.stack(audio_chunks)
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
@@ -144,20 +145,24 @@ def decode_audio_torchaudio(
|
|||||||
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
logging.warning("Audio stream reached end of recording before decoding desired timestamps.")
|
||||||
|
|
||||||
current_audio_chunk = reader.pop_chunks()[0]
|
current_audio_chunk = reader.pop_chunks()[0]
|
||||||
|
current_audio_chunk_data = current_audio_chunk.t() # Channel first format
|
||||||
|
|
||||||
# Case where the requested audio chunk starts before the beginning of the audio stream
|
# Case where the requested audio chunk starts before the beginning of the audio stream
|
||||||
if ts - duration < 0:
|
if ts - duration < 0:
|
||||||
# No useful audio sample has been recorded
|
# No useful audio sample has been recorded
|
||||||
if ts < 1 / audio_sample_rate:
|
if ts < 1 / audio_sample_rate:
|
||||||
current_audio_chunk = torch.zeros((int(ceil(duration * audio_sample_rate)), audio_channels))
|
current_audio_chunk_data = torch.zeros(
|
||||||
|
(audio_channels, int(ceil(duration * audio_sample_rate)))
|
||||||
|
)
|
||||||
# At least one useful audio sample has been recorded
|
# At least one useful audio sample has been recorded
|
||||||
else:
|
else:
|
||||||
# Remove the superfluous last samples of the audio chunk
|
# Remove the superfluous last samples of the audio chunk
|
||||||
current_audio_chunk = current_audio_chunk[: int(ceil(ts * audio_sample_rate))]
|
current_audio_chunk_data = current_audio_chunk_data[:, : int(ceil(ts * audio_sample_rate))]
|
||||||
# Pad the beginning of the audio chunk with zeros
|
# Pad the beginning of the audio chunk with zeros
|
||||||
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
# TODO(CarolinePascal) : add low level white noise instead of zeros ?
|
||||||
current_audio_chunk = torch.nn.functional.pad(
|
current_audio_chunk_data = torch.nn.functional.pad(
|
||||||
current_audio_chunk, (0, 0, int(ceil((duration - ts) * audio_sample_rate)), 0)
|
current_audio_chunk_data,
|
||||||
|
(int(ceil((duration - ts) * audio_sample_rate)), 0, 0, 0), # left, right, top, bottom
|
||||||
)
|
)
|
||||||
|
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
@@ -165,7 +170,7 @@ def decode_audio_torchaudio(
|
|||||||
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
f"audio chunk loaded at starting timestamp={current_audio_chunk['pts']:.4f} with duration={len(current_audio_chunk) / audio_sample_rate:.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_chunks.append(current_audio_chunk)
|
audio_chunks.append(current_audio_chunk_data)
|
||||||
|
|
||||||
audio_chunks = torch.stack(audio_chunks)
|
audio_chunks = torch.stack(audio_chunks)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user