Learning from Sound

technical
intermediate
audio
Classifying respiratory sounds with PyTorch and torchaudio
Published September 14, 2021

Some utility functions for this notebook
# To be used with torchaudio
def print_stats(waveform, sample_rate=None, src=None):
  if src:
    print("-" * 10)
    print("Source:", src)
    print("-" * 10)
  if sample_rate:
    print("Sample Rate:", sample_rate)
  print("Shape:", tuple(waveform.shape))
  print("Dtype:", waveform.dtype)
  print(f" - Max:     {waveform.max().item():6.3f}")
  print(f" - Min:     {waveform.min().item():6.3f}")
  print(f" - Mean:    {waveform.mean().item():6.3f}")
  print(f" - Std Dev: {waveform.std().item():6.3f}")
  print()
  print(waveform)
  print()

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None, ylim=None):
  waveform = waveform.numpy()

  num_channels, num_frames = waveform.shape
  time_axis = torch.arange(0, num_frames) / sample_rate

  figure, axes = plt.subplots(num_channels, 1)
  if num_channels == 1:
    axes = [axes]
  for c in range(num_channels):
    axes[c].plot(time_axis, waveform[c], linewidth=1)
    axes[c].grid(True)
    if num_channels > 1:
      axes[c].set_ylabel(f'Channel {c+1}')
    if xlim:
      axes[c].set_xlim(xlim)
    if ylim:
      axes[c].set_ylim(ylim)
  figure.suptitle(title)
  plt.show(block=False)

def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
  waveform = waveform.numpy()

  num_channels, num_frames = waveform.shape
  time_axis = torch.arange(0, num_frames) / sample_rate

  figure, axes = plt.subplots(num_channels, 1)
  if num_channels == 1:
    axes = [axes]
  for c in range(num_channels):
    axes[c].specgram(waveform[c], Fs=sample_rate)
    if num_channels > 1:
      axes[c].set_ylabel(f'Channel {c+1}')
    if xlim:
      axes[c].set_xlim(xlim)
  figure.suptitle(title)
  plt.show(block=False)

def play_audio(waveform, sample_rate):
  waveform = waveform.numpy()

  num_channels, num_frames = waveform.shape
  if num_channels == 1:
    display(Audio(waveform[0], rate=sample_rate))
  elif num_channels == 2:
    display(Audio((waveform[0], waveform[1]), rate=sample_rate))
  else:
    raise ValueError("Waveforms with more than 2 channels are not supported.")

def inspect_file(path):
  print("-" * 10)
  print("Source:", path)
  print("-" * 10)
  print(f" - File size: {os.path.getsize(path)} bytes")
  print(f" - {torchaudio.info(path)}")

def plot_spectrogram(spec, title=None, ylabel='freq_bin', aspect='auto', xmax=None):
  fig, axs = plt.subplots(1, 1)
  axs.set_title(title or 'Spectrogram (db)')
  axs.set_ylabel(ylabel)
  axs.set_xlabel('frame')
  im = axs.imshow(librosa.power_to_db(spec), origin='lower', aspect=aspect)
  if xmax:
    axs.set_xlim((0, xmax))
  fig.colorbar(im, ax=axs)
  plt.show(block=False)

The experimental version of this notebook can be found in this repo: Learning from Sound - Experimental

This notebook assumes basic knowledge of training neural networks, of what a CNN is and related concepts such as batch norm, and of how sound is represented in digital format.

To learn about the latter, you can go through this 6-part blog series, which explains digital audio from the ground up. (The first four posts will be sufficient for this notebook.)

Imports
import os
import random
from collections import Counter
import librosa

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

from fastcore.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler

import torchaudio
import torchaudio.transforms as T

from IPython.display import Audio, display

Introduction

Respiratory sounds are important indicators of respiratory health and respiratory disorders. The sound emitted when a person breathes is directly related to air movement, changes within lung tissue and the position of secretions within the lung. A wheezing sound, for example, is a common sign that a patient has an obstructive airway disease like asthma or chronic obstructive pulmonary disease (COPD).

These sounds can be recorded using digital stethoscopes and other recording techniques. This digital data opens up the possibility of using machine learning to automatically diagnose respiratory disorders like asthma, pneumonia and bronchiolitis, to name a few.

In this notebook, we are going to try and create a Convolutional Neural Network that can distinguish and classify different respiratory sounds and make a diagnosis. In the process, we are going to learn how sound is represented in digital format, how to convert audio files into spectrograms that the CNN can learn from, and a few other things about training neural networks.

I learned a lot from other people while making this notebook and I reference all of them at the bottom.

Getting the Data

Luckily for us, two research teams in Portugal and Greece already prepared a suitable dataset that can be found on Kaggle. It includes 920 annotated recordings of varying length - 10s to 90s. These recordings were taken from 126 patients. There are a total of 5.5 hours of recordings containing 6898 respiratory cycles.

We can download the dataset using the kaggle command.

!kaggle datasets download -d vbookshelf/respiratory-sound-database
Downloading respiratory-sound-database.zip to /content
100% 3.68G/3.69G [01:38<00:00, 24.1MB/s]
100% 3.69G/3.69G [01:38<00:00, 40.2MB/s]
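The download is a single zip archive. To get the folder layout used in the rest of the notebook, we can extract it into a data/ directory (the target folder here is an assumption chosen to match the data_path used below; the nested folder names come from the archive itself):

!unzip -q respiratory-sound-database.zip -d data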

Working with torchaudio

We are going to be using PyTorch and torchaudio in this notebook.

Let’s create a pathlib object pointing to where our data is located:

data_path = Path('data/respiratory_sound_database/Respiratory_Sound_Database')

We can see what files are present in our data_path:

data_path.ls()
(#4) [Path('patient_diagnosis.csv'),Path('audio_and_txt_files'),Path('filename_format.txt'),Path('filename_differences.txt')]
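patient_diagnosis.csv maps each patient ID to a diagnosis, which is what we will eventually use as labels. A minimal sketch of loading it with pandas (assuming the file ships without a header row, so we supply our own column names):

diagnosis_df = pd.read_csv(data_path/'patient_diagnosis.csv',
                           names=['patient_id', 'diagnosis'])
diagnosis_df.head()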

Let's grab one audio file to use as our example:

(data_path/'audio_and_txt_files').ls(file_exts='.wav')[0]
Path('audio_and_txt_files/138_1p2_Ar_mc_AKGC417L.wav')
AUDIO_FILE = (data_path/'audio_and_txt_files').ls(file_exts='.wav')[0]
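According to the dataset's filename_format.txt, each recording's name encodes the patient number, recording index, chest location, acquisition mode and recording equipment, separated by underscores. A quick sketch of splitting those fields apart (the variable names are our own):

name = AUDIO_FILE.stem  # '138_1p2_Ar_mc_AKGC417L'
patient_id, rec_index, chest_loc, mode, equipment = name.split('_')
patient_id, chest_loc, equipment
# ('138', 'Ar', 'AKGC417L')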

Let us load that audio file using torchaudio. It returns a tuple containing the waveform and its sample rate.

waveform, sample_rate = torchaudio.load(AUDIO_FILE)
waveform.shape, sample_rate
(torch.Size([1, 882000]), 44100)

Our example audio file has a shape of [1, 882000] and a sample rate of 44100 Hz (44.1 kHz), which is pretty common for audio recordings.
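As a quick sanity check, dividing the number of frames by the sample rate gives the clip's duration in seconds:

num_frames = waveform.shape[1]
num_frames / sample_rate  # 882000 / 44100 = 20.0 seconds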

Other info about the audio file can be seen using the following handy utility function:

print_stats(waveform)
Shape: (1, 882000)
Dtype: torch.float32
 - Max:      0.899
 - Min:     -0.623
 - Mean:     0.000
 - Std Dev:  0.112

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0847, 0.0853, 0.0724]])

We can plot the waveform of the audio file:

plot_waveform(waveform, sample_rate);

As you can see, the waveform is still a raw signal, while a CNN expects an image-like input, so we need a way to convert the signal into an image. A spectrogram is a visual representation of the spectrum of frequencies of a signal as it varies with time.

Here is a spectrogram of the example audio above:

plot_specgram(waveform, sample_rate);
/usr/local/lib/python3.7/dist-packages/matplotlib/axes/_axes.py:7592: RuntimeWarning: divide by zero encountered in log10
  Z = 10. * np.log10(spec)
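The same kind of representation can also be computed directly on the tensor with torchaudio's Spectrogram transform, which is handy when the spectrogram needs to stay inside the PyTorch pipeline. A minimal sketch (the parameter values are only illustrative):

spectrogram = T.Spectrogram(
    n_fft=1024,      # 513 frequency bins
    hop_length=512,  # stride between successive frames
    power=2.0,       # power spectrogram (magnitude squared)
)
spec = spectrogram(waveform)  # shape: [1, 513, time_frames]
plot_spectrogram(spec[0], title='Spectrogram');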

As the plot above shows, an ordinary spectrogram won't give our CNN much to learn from. Mel spectrograms work better in this case, and computing one from the waveform is easy with torchaudio's MelSpectrogram transform.

n_fft = 1024
win_length = None
hop_length = 512
n_mels = 128

mel_spectrogram = T.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    power=2.0,
    norm='slaney',
    onesided=True,
    n_mels=n_mels,
    mel_scale="htk",
)

melspec = mel_spectrogram(waveform)
plot_spectrogram(
    melspec[0], title="MelSpectrogram", ylabel='mel freq');
/usr/local/lib/python3.7/dist-packages/torchaudio/functional/functional.py:433: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (513) may be set too low.
  "At least one mel filterbank has all zero values. "

This is a better visual representation than ordinary spectrograms and gives our neural network something to work with.
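Note that plot_spectrogram converts the power values to decibels with librosa.power_to_db before displaying them. If you would rather stay entirely in torchaudio, the AmplitudeToDB transform does the equivalent conversion on the tensor (a small sketch):

to_db = T.AmplitudeToDB(stype='power', top_db=80)
melspec_db = to_db(melspec)  # same shape as melspec, values now in dB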

Finally, we can play the audio and hear the respiratory recording.

play_audio(waveform, sample_rate);