Merge pull request #95 from felixbur/cnn

Cnn

felixbur authored Nov 16, 2023
2 parents 5ee38c1 + e3e5a99 commit db15bb2
Showing 8 changed files with 324 additions and 137 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.69.0
--------------
* added CNN model and melspec extractor

Version 0.68.4
--------------
* bugfix: got_gender was incorrectly set
2 changes: 1 addition & 1 deletion ini_file.md
@@ -237,7 +237,7 @@
* **xgr**: XG-Boost Regression
* **mlp**: Multi-Layer-Perceptron for classification
* **mlp_reg**: Multi-Layer-Perceptron for regression
* **cnn**: Convolutional neural network (tbd)
* **cnn**: Convolutional neural network (only works with feature type=spectra)
* **tuning_params**: possible tuning parameters for x-fold optimization (for Bayes, KNN, KNN_reg, Tree, Tree_reg, SVM, SVR, XGB and XGR)
* tuning_params = ['subsample', 'n_estimators', 'max_depth']
* subsample = [.5, .7]
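As a quick illustration of how the new cnn model documented above pairs with the spectra feature type, here is a minimal configuration sketch, expressed as Python with configparser rather than a raw INI file. The [FEATS] keys (fft_nbands, fft_win_dur, fft_hop_dur) are the ones read by feats_spectra.py below; the exact section layout, the list syntax for the type key, and the values shown are assumptions for illustration, not part of this diff.

```python
# Hypothetical experiment configuration combining the new cnn model with the
# spectra feature extractor. Section/key names mirror ini_file.md and the
# config_val() calls in feats_spectra.py; values are illustrative defaults.
import configparser

ini_text = """
[FEATS]
type = ['spectra']
fft_nbands = 64
fft_win_dur = 25
fft_hop_dur = 10

[MODEL]
type = cnn
"""

config = configparser.ConfigParser()
config.read_string(ini_text)

# Spectraloader reads its parameters the same way, falling back to these
# defaults when a key is missing.
num_bands = int(config.get("FEATS", "fft_nbands", fallback="64"))
win_dur = int(config.get("FEATS", "fft_win_dur", fallback="25"))
hop_dur = int(config.get("FEATS", "fft_hop_dur", fallback="10"))
print(num_bands, win_dur, hop_dur)  # 64 25 10
```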
89 changes: 89 additions & 0 deletions nkululeko/feat_extract/feats_spectra.py
@@ -0,0 +1,89 @@
"""
feats_spectra.py
Inspired by code from Su Lei
"""
import os
import torchaudio
import torchaudio.transforms as T
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pathlib
import audeer

from nkululeko.feat_extract.featureset import Featureset
from nkululeko.constants import SAMPLING_RATE
import nkululeko.glob_conf as glob_conf


class Spectraloader(Featureset):
def __init__(self, name, data_df):
"""Constructor setting the name"""
Featureset.__init__(self, name, data_df)
self.sampling_rate = SAMPLING_RATE
self.num_bands = int(self.util.config_val("FEATS", "fft_nbands", "64"))
self.win_dur = int(self.util.config_val("FEATS", "fft_win_dur", "25"))
self.hop_dur = int(self.util.config_val("FEATS", "fft_hop_dur", "10"))

def extract(self):
"""Extract the features or load them from disk if present."""
store = self.util.get_path("store")
store_format = self.util.config_val("FEATS", "store_format", "pkl")
storage = f"{store}{self.name}.{store_format}"
extract = self.util.config_val("FEATS", "needs_feature_extraction", False)
no_reuse = eval(self.util.config_val("FEATS", "no_reuse", "False"))
if extract or no_reuse or not os.path.isfile(storage):
self.util.debug("extracting mel spectra, this might take a while...")
image_store = audeer.mkdir(f"{store}{self.name}")
images = []
for idx, (file, start, end) in enumerate(
tqdm(self.data_df.index.to_list())
):
signal, sampling_rate = torchaudio.load(
file,
frame_offset=int(start.total_seconds() * 16000),
num_frames=int((end - start).total_seconds() * 16000),
)
assert sampling_rate == 16000, f"got {sampling_rate} instead of 16000"
image = self._waveform2rgb(signal)
outfile = f"{image_store}/{pathlib.Path(file).stem}_{idx}.jpg"
image.save(outfile)
images.append(outfile)
self.df = pd.DataFrame(images, index=self.data_df.index)
self.util.write_store(self.df, storage, store_format)
try:
glob_conf.config["DATA"]["needs_feature_extraction"] = "false"
except KeyError:
pass
else:
self.util.debug("reusing extracted spectrograms")
self.df = self.util.get_store(storage, store_format)

def _waveform2rgb(self, waveform, target_size=(256, 256)):
# Transform to spectrogram
spectrogram = T.MelSpectrogram(
sample_rate=SAMPLING_RATE,
n_mels=self.num_bands,
hop_length=int(self.hop_dur * SAMPLING_RATE / 1000),
win_length=int(self.win_dur * SAMPLING_RATE / 1000),
)(waveform)
melspec = T.AmplitudeToDB()(spectrogram)[0].numpy()
melspec_norm = (melspec - np.min(melspec)) / (np.max(melspec) - np.min(melspec))

# Map normalized Mel spectrogram to viridis colormap
cmapped = plt.get_cmap("viridis")(melspec_norm)

# Convert this colormap representation to a format suitable for creating a PIL Image
image_array = (cmapped[:, :, :3] * 255).astype(np.uint8)
image = Image.fromarray(image_array, mode="RGB")
image = ImageOps.flip(image)

# Resize to target size
image = image.resize(target_size, Image.Resampling.LANCZOS)
return image
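To see what Spectraloader._waveform2rgb produces, the following standalone sketch runs the same mel-spectrogram-to-RGB pipeline on a synthetic 440 Hz sine instead of a loaded audio file. It mirrors the defaults above (16 kHz sampling rate, 64 mel bands, 25 ms window, 10 ms hop, 256x256 output); the synthetic signal and the output filename are illustrative only and not part of the committed code.

```python
# Standalone sketch of the mel-spectrogram-to-RGB conversion used by
# Spectraloader._waveform2rgb, driven by a synthetic sine wave.
import math
import torch
import torchaudio.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps

SR = 16000                      # sampling rate assumed by the extractor
t = torch.arange(SR) / SR       # one second of samples
waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)  # shape (1, SR)

melspec = T.MelSpectrogram(
    sample_rate=SR,
    n_mels=64,                          # FEATS fft_nbands default
    win_length=int(25 * SR / 1000),     # FEATS fft_win_dur default (ms)
    hop_length=int(10 * SR / 1000),     # FEATS fft_hop_dur default (ms)
)(waveform)
melspec_db = T.AmplitudeToDB()(melspec)[0].numpy()

# Normalize to [0, 1], map through the viridis colormap, drop the alpha channel.
norm = (melspec_db - melspec_db.min()) / (melspec_db.max() - melspec_db.min())
rgb = (plt.get_cmap("viridis")(norm)[:, :, :3] * 255).astype(np.uint8)

# Flip so low frequencies end up at the bottom, then resize as the extractor does.
image = ImageOps.flip(Image.fromarray(rgb, mode="RGB"))
image = image.resize((256, 256), Image.Resampling.LANCZOS)
image.save("sine_440hz_melspec.jpg")  # illustrative output path
```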
12 changes: 8 additions & 4 deletions nkululeko/feature_extractor.py
Expand Up @@ -42,9 +42,16 @@ def extract(self):
store_name = f"{self.data_name}_{feats_type}"
if feats_type == "os":
from nkululeko.feat_extract.feats_opensmile import Opensmileset

self.featExtractor = Opensmileset(
f"{store_name}_{self.feats_designation}", self.data_df
)
elif feats_type == "spectra":
from nkululeko.feat_extract.feats_spectra import Spectraloader

self.featExtractor = Spectraloader(
f"{store_name}_{self.feats_designation}", self.data_df
)
elif feats_type == "trill":
from nkululeko.feat_extract.feats_trill import TRILLset

@@ -85,7 +92,6 @@ def extract(self):
self.data_df,
feats_type,
)

elif feats_type == "audmodel":
from nkululeko.feat_extract.feats_audmodel import AudModelSet

@@ -166,9 +172,7 @@ def extract(self):
# remove samples that were not extracted by MLD
# self.df_test = self.df_test.loc[self.df_test.index.intersection(featExtractor_test.df.index)]
# self.df_train = self.df_train.loc[self.df_train.index.intersection(featExtractor_train.df.index)]
self.util.debug(
f"{feats_type}: shape : {self.featExtractor.df.shape}"
)
self.util.debug(f"{feats_type}: shape : {self.featExtractor.df.shape}")
self.feats = pd.concat([self.feats, self.featExtractor.df], axis=1)
return self.feats

5 changes: 1 addition & 4 deletions nkululeko/modelrunner.py
@@ -46,9 +46,7 @@ def do_epochs(self):
self.model.train()
report = self.model.predict()
report.set_id(self.run, epoch)
plot_name = (
self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
)
plot_name = self.util.get_plot_name() + f"_{self.run}_{epoch:03d}_cnf"
reports.append(report)
self.util.debug(
f"run: {self.run} epoch: {epoch}: result: "
@@ -132,7 +130,6 @@ def _select_model(self, model_type):
)
elif model_type == "cnn":
from nkululeko.models.model_cnn import CNN_model

self.model = CNN_model(
self.df_train, self.df_test, self.feats_train, self.feats_test
