Skip to content

Commit

Permalink
0.91.3
Browse files Browse the repository at this point in the history
  • Loading branch information
FBurkhardt committed Nov 5, 2024
1 parent 2cde386 commit 5d0acb5
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Version 0.91.3
--------------
* some additions for robustness

Version 0.91.2
--------------
* making lint work by excluding constants from check
Expand Down
43 changes: 43 additions & 0 deletions nkululeko/autopredict/ap_sid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
A predictor for sid - Speaker ID.
"""

from pyannote.audio import Pipeline


import numpy as np

import nkululeko.glob_conf as glob_conf
from nkululeko.feature_extractor import FeatureExtractor
from nkululeko.utils.util import Util


class SIDPredictor:
    """Predictor registered under the name "sid" (speaker ID).

    NOTE(review): although the class is named for speaker identification and
    loads a pyannote diarization pipeline in ``__init__``, ``predict`` below
    does not use that pipeline; it extracts ``squim`` features and returns a
    PESQ estimate in the column ``pesq_pred``. This looks like a work-in-progress
    stub copied from the PESQ predictor -- confirm the intended behavior.
    """

    def __init__(self, df):
        """Store the data and load the pretrained diarization pipeline.

        Args:
            df: the data to predict on; it is copied in ``predict`` and handed
                to ``FeatureExtractor`` (presumably a pandas DataFrame in
                audformat layout -- TODO confirm).
        """
        self.df = df
        self.util = Util("sidPredictor")
        # NOTE(review): the token is a hard-coded placeholder; presumably it
        # has to be replaced by a real Hugging Face access token (ideally read
        # from configuration or the environment) before this can load.
        self.pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token="HUGGINGFACE_ACCESS_TOKEN_GOES_HERE",
        )

    def predict(self, split_selection):
        """Return a copy of the data with an added ``pesq_pred`` column.

        Args:
            split_selection: forwarded to ``FeatureExtractor`` and used in the
                debug message (presumably a split name such as "train" or
                "test" -- TODO confirm).

        Returns:
            A copy of ``self.df`` with the column ``pesq_pred`` holding the
            extracted ``pesq`` values truncated to two decimals.
        """
        # Imported locally because the module never imports ``ast``; without
        # it the ``literal_eval`` call below raises NameError.
        import ast

        self.util.debug(f"estimating PESQ for {split_selection} samples")
        return_df = self.df.copy()
        # The feature-set name is built from the configured database names.
        databases = ast.literal_eval(glob_conf.config["DATA"]["databases"])
        feats_name = "_".join(databases)
        self.feature_extractor = FeatureExtractor(
            self.df, ["squim"], feats_name, split_selection
        )
        result_df = self.feature_extractor.extract()
        # Replace missing and infinite values by 0 so the integer cast below
        # cannot fail.
        result_df = result_df.replace([np.inf, -np.inf], 0).fillna(0)
        # Scale, cast to int and scale back: truncates to two decimals.
        pred_vals = result_df.pesq * 100
        return_df["pesq_pred"] = pred_vals.astype("int") / 100
        return return_df
2 changes: 1 addition & 1 deletion nkululeko/constants.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION="0.91.2"
VERSION="0.91.3"
SAMPLING_RATE = 16000
8 changes: 7 additions & 1 deletion nkululeko/data/dataset_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def load(self):
df = audformat.utils.read_csv(data_file)
if isinstance(df, pd.Series):
df = df.to_frame()
elif isinstance(df, pd.Index):
df = pd.DataFrame(index=df)
rename_cols = self.util.config_val_data(self.name, "colnames", False)
if rename_cols:
col_dict = ast.literal_eval(rename_cols)
Expand Down Expand Up @@ -78,7 +80,11 @@ def load(self):

self.df = df
self.db = None
self.got_target = True
target = self.util.config_val("DATA", "target", None)
if target is not None:
self.got_target = True
else:
self.got_target = False
self.is_labeled = self.got_target
self.start_fresh = eval(self.util.config_val("DATA", "no_reuse", "False"))
is_index = False
Expand Down
10 changes: 9 additions & 1 deletion nkululeko/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,15 @@ def _import_csv(self, storage):
# df = pd.read_csv(storage, header=0, index_col=[0,1,2])
# df.index.set_levels(pd.to_timedelta(df.index.levels[1]), level=1)
# df.index.set_levels(pd.to_timedelta(df.index.levels[2]), level=2)
df = audformat.utils.read_csv(storage)
try:
df = audformat.utils.read_csv(storage)
except ValueError:
# split might be empty
return pd.DataFrame()
if isinstance(df, pd.Series):
df = df.to_frame()
elif isinstance(df, pd.Index):
df = pd.DataFrame(index=df)
df.is_labeled = True if self.target in df else False
# print(df.head())
return df
Expand Down
3 changes: 3 additions & 0 deletions nkululeko/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def plot_durations(self, df, filename, sample_selection, caption=""):
except AttributeError as ae:
self.util.warn(ae)
ax = sns.histplot(df, x="duration", kde=True)
except ValueError as error:
self.util.warn(error)
ax = sns.histplot(df, x="duration", kde=True)
min = self.util.to_3_digits(df.duration.min())
max = self.util.to_3_digits(df.duration.max())
title = f"Duration distr. for {sample_selection} {df.shape[0]}. min={min}, max={max}"
Expand Down
2 changes: 1 addition & 1 deletion nkululeko/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def calc_dur(x):
)
print("")
# remove encoded labels
target = util.config_val("DATA", "target", "emotion")
target = util.config_val("DATA", "target", None)
if "class_label" in df_seg.columns:
df_seg = df_seg.drop(columns=[target])
df_seg = df_seg.rename(columns={"class_label": target})
Expand Down

0 comments on commit 5d0acb5

Please sign in to comment.