From 70165cb42de94da9d3c81fdfd181a4ec29b1cdf4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 6 Oct 2024 16:37:29 +0800 Subject: [PATCH] Speaker diarization example with onnxruntime Python API (#1395) --- .github/workflows/speaker-diarization.yaml | 98 ++++ .gitignore | 2 + scripts/pyannote/segmentation/README.md | 44 ++ .../segmentation/speaker-diarization-onnx.py | 488 ++++++++++++++++++ .../segmentation/speaker-diarization-torch.py | 86 +++ sherpa-onnx/csrc/fast-clustering.cc | 2 +- 6 files changed, 719 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/speaker-diarization.yaml create mode 100644 scripts/pyannote/segmentation/README.md create mode 100755 scripts/pyannote/segmentation/speaker-diarization-onnx.py create mode 100755 scripts/pyannote/segmentation/speaker-diarization-torch.py diff --git a/.github/workflows/speaker-diarization.yaml b/.github/workflows/speaker-diarization.yaml new file mode 100644 index 000000000..0bd6a575c --- /dev/null +++ b/.github/workflows/speaker-diarization.yaml @@ -0,0 +1,98 @@ +name: speaker-diarization + +on: + push: + branches: + - speaker-diarization + workflow_dispatch: + +concurrency: + group: speaker-diarization-${{ github.ref }} + cancel-in-progress: true + +jobs: + linux: + name: speaker diarization + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + python-version: ["3.10"] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ matrix.os }}-speaker-diarization + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install pyannote + shell: bash + run: | + pip install pyannote.audio onnx onnxruntime + + - name: Install sherpa-onnx from source + shell: bash + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel twine setuptools + + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + + cat sherpa-onnx/python/sherpa_onnx/__init__.py + + python3 setup.py bdist_wheel + ls -lh dist + pip install ./dist/*.whl + + - name: Run tests + shell: bash + run: | + pushd scripts/pyannote/segmentation + + python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)" + python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)" + python3 -c "import sherpa_onnx; print(dir(sherpa_onnx))" + + curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin + + test_wavs=( + 0-two-speakers-zh.wav + 1-two-speakers-en.wav + 2-two-speakers-en.wav + 3-two-speakers-en.wav + ) + + for w in ${test_wavs[@]}; do + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$w + done + + soxi *.wav + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + ls -lh sherpa-onnx-pyannote-segmentation-3-0 + + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + + for w in ${test_wavs[@]}; do + echo "---------test $w (onnx)----------" + time ./speaker-diarization-onnx.py \ + --seg-model ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \ + --speaker-embedding-model ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \ + --wav $w + + echo "---------test $w (torch)----------" + time ./speaker-diarization-torch.py --wav $w + done diff --git a/.gitignore b/.gitignore index bf8ca193f..b0fbfae78 100644 --- a/.gitignore +++ b/.gitignore @@ -118,3 +118,5 @@ vits-melo-tts-zh_en *.o *.ppu sherpa-onnx-online-punct-en-2024-08-06 +*.mp4 +*.mp3 diff --git a/scripts/pyannote/segmentation/README.md b/scripts/pyannote/segmentation/README.md new file mode 100644 index 000000000..a2e35b2de --- /dev/null +++ b/scripts/pyannote/segmentation/README.md @@ -0,0 +1,44 @@ +# File description + +Please download test wave files from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models + +## 0-two-speakers-zh.wav + +This file is from +https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0 + +Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`. + +## 1-two-speakers-en.wav + +This file is from +https://github.com/pengzhendong/pyannote-onnx/blob/master/data/test_16k.wav +and it contains speeches from two speakers. + +Note that we have renamed it from `test_16k.wav` to `1-two-speakers-en.wav` + + +## 2-two-speakers-en.wav +This file is from +https://huggingface.co/spaces/Xenova/whisper-speaker-diarization + +Note that the original file is `./fcf059e3-689f-47ec-a000-bdace87f0113.mp4`. +We use the following commands to convert it to `2-two-speakers-en.wav`. + +```bash +ffmpeg -i ./fcf059e3-689f-47ec-a000-bdace87f0113.mp4 -ac 1 -ar 16000 ./2-two-speakers-en.wav +``` + +## 3-two-speakers-en.wav + +This file is from +https://aws.amazon.com/blogs/machine-learning/deploy-a-hugging-face-pyannote-speaker-diarization-model-on-amazon-sagemaker-as-an-asynchronous-endpoint/ + +Note that the original file is `ML16091-Audio.mp3`. We use the following +commands to convert it to `3-two-speakers-en.wav` + + +```bash +sox ML16091-Audio.mp3 3-two-speakers-en.wav +``` diff --git a/scripts/pyannote/segmentation/speaker-diarization-onnx.py b/scripts/pyannote/segmentation/speaker-diarization-onnx.py new file mode 100755 index 000000000..eddd841db --- /dev/null +++ b/scripts/pyannote/segmentation/speaker-diarization-onnx.py @@ -0,0 +1,488 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +""" +Please refer to +https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml +for usages. +""" + +import argparse +from datetime import timedelta +from pathlib import Path +from typing import List + +import librosa +import numpy as np +import onnxruntime as ort +import sherpa_onnx +import soundfile as sf +from numpy.lib.stride_tricks import as_strided + + +class Segment: + def __init__( + self, + start, + end, + speaker, + ): + assert start < end + self.start = start + self.end = end + self.speaker = speaker + + def merge(self, other, gap=0.5): + assert self.speaker == other.speaker, (self.speaker, other.speaker) + if self.end < other.start and self.end + gap >= other.start: + return Segment(start=self.start, end=other.end, speaker=self.speaker) + elif other.end < self.start and other.end + gap >= self.start: + return Segment(start=other.start, end=self.end, speaker=self.speaker) + else: + return None + + @property + def duration(self): + return self.end - self.start + + def __str__(self): + s = f"{timedelta(seconds=self.start)}"[:-3] + s += " --> " + s += f"{timedelta(seconds=self.end)}"[:-3] + s += f" speaker_{self.speaker:02d}" + return s + + +def merge_segment_list(in_out: List[Segment], min_duration_off: float): + changed = True + while changed: + changed = False + for i in range(len(in_out)): + if i + 1 >= len(in_out): + continue + + new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off) + if new_segment is None: + continue + del in_out[i + 1] + in_out[i] = new_segment + changed = True + break + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--seg-model", + type=str, + required=True, + help="Path to model.onnx for segmentation", + ) + parser.add_argument( + "--speaker-embedding-model", + type=str, + required=True, + help="Path to model.onnx for speaker embedding extractor", + ) + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") + + return parser.parse_args() + + +class OnnxSegmentationModel: + def __init__(self, filename): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + meta = self.model.get_modelmeta().custom_metadata_map + print(meta) + + self.window_size = int(meta["window_size"]) + self.sample_rate = int(meta["sample_rate"]) + self.window_shift = int(0.1 * self.window_size) + self.receptive_field_size = int(meta["receptive_field_size"]) + self.receptive_field_shift = int(meta["receptive_field_shift"]) + self.num_speakers = int(meta["num_speakers"]) + self.powerset_max_classes = int(meta["powerset_max_classes"]) + self.num_classes = int(meta["num_classes"]) + + def __call__(self, x): + """ + Args: + x: (N, num_samples) + Returns: + A tensor of shape (N, num_frames, num_classes) + """ + x = np.expand_dims(x, axis=1) + + (y,) = self.model.run( + [self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x} + ) + + return y + + +def load_wav(filename, expected_sample_rate) -> np.ndarray: + audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + if sample_rate != expected_sample_rate: + audio = librosa.resample( + audio, + orig_sr=sample_rate, + target_sr=expected_sample_rate, + ) + return audio + + +def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes): + mapping = np.zeros((num_classes, num_speakers)) + + k = 1 + for i in range(1, powerset_max_classes + 1): + if i == 1: + for j in range(0, num_speakers): + mapping[k, j] = 1 + k += 1 + elif i == 2: + for j in range(0, num_speakers): + for m in range(j + 1, num_speakers): + mapping[k, j] = 1 + mapping[k, m] = 1 + k += 1 + elif i == 3: + raise RuntimeError("Unsupported") + + return mapping + + +def to_multi_label(y, mapping): + """ + Args: + y: (num_chunks, num_frames, num_classes) + Returns: + A tensor of shape (num_chunks, num_frames, num_speakers) + """ + y = np.argmax(y, axis=-1) + labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1) + return labels + + +# speaker count per frame +def speaker_count(labels, seg_m): + """ + Args: + labels: (num_chunks, num_frames, num_speakers) + seg_m: Segmentation model + Returns: + A integer array of shape (num_total_frames,) + """ + labels = labels.sum(axis=-1) + # Now labels: (num_chunks, num_frames) + + num_frames = ( + int( + (seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift) + / seg_m.receptive_field_shift + ) + + 1 + ) + ans = np.zeros((num_frames,)) + count = np.zeros((num_frames,)) + + for i in range(labels.shape[0]): + this_chunk = labels[i] + start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5) + end = start + this_chunk.shape[0] + ans[start:end] += this_chunk + count[start:end] += 1 + + ans /= np.maximum(count, 1e-12) + + return (ans + 0.5).astype(np.int8) + + +def load_speaker_embedding_model(filename): + config = sherpa_onnx.SpeakerEmbeddingExtractorConfig( + model=filename, + num_threads=1, + debug=0, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config) + return extractor + + +def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap): + """ + Args: + embedding_filename: Path to the speaker embedding extractor model + audio: (num_samples,) + labels: (num_chunks, num_frames, num_speakers) + seg_m: segmentation model + Returns: + Return (num_chunks, num_speakers, embedding_dim) + """ + if exclude_overlap: + labels = labels * (labels.sum(axis=-1, keepdims=True) < 2) + + extractor = load_speaker_embedding_model(embedding_filename) + buffer = np.empty(seg_m.window_size) + num_chunks, num_frames, num_speakers = labels.shape + + ans_chunk_speaker_pair = [] + ans_embeddings = [] + + for i in range(num_chunks): + labels_T = labels[i].T + # t: (num_speakers, num_frames) + + sample_offset = i * seg_m.window_shift + + for j in range(num_speakers): + frames = labels_T[j] + if frames.sum() < 10: + # skip segment less than 20 frames, i.e., about 0.2 seconds + continue + + start = None + start_samples = 0 + idx = 0 + for k in range(num_frames): + if frames[k] != 0: + if start is None: + start = k + elif start is not None: + start_samples = ( + int(start / num_frames * seg_m.window_size) + sample_offset + ) + end_samples = ( + int(k / num_frames * seg_m.window_size) + sample_offset + ) + num_samples = end_samples - start_samples + buffer[idx : idx + num_samples] = audio[start_samples:end_samples] + idx += num_samples + + start = None + if start is not None: + start_samples = ( + int(start / num_frames * seg_m.window_size) + sample_offset + ) + end_samples = int(k / num_frames * seg_m.window_size) + sample_offset + num_samples = end_samples - start_samples + buffer[idx : idx + num_samples] = audio[start_samples:end_samples] + idx += num_samples + + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx]) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + + ans_chunk_speaker_pair.append([i, j]) + ans_embeddings.append(embedding) + + assert len(ans_chunk_speaker_pair) == len(ans_embeddings), ( + len(ans_chunk_speaker_pair), + len(ans_embeddings), + ) + return ans_chunk_speaker_pair, np.array(ans_embeddings) + + +def main(): + args = get_args() + assert Path(args.seg_model).is_file(), args.seg_model + assert Path(args.wav).is_file(), args.wav + + seg_m = OnnxSegmentationModel(args.seg_model) + audio = load_wav(args.wav, seg_m.sample_rate) + # audio: (num_samples,) + + num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1 + + samples = as_strided( + audio, + shape=(num, seg_m.window_size), + strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]), + ) + + # or use torch.Tensor.unfold + # samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy() + + if ( + audio.shape[0] < seg_m.window_size + or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0 + ): + has_last_chunk = True + else: + has_last_chunk = False + + num_chunks = samples.shape[0] + batch_size = 32 + output = [] + for i in range(0, num_chunks, batch_size): + start = i + end = i + batch_size + # it's perfectly ok to use end > num_chunks + y = seg_m(samples[start:end]) + output.append(y) + + if has_last_chunk: + last_chunk = audio[num_chunks * seg_m.window_shift :] # noqa + pad_size = seg_m.window_size - last_chunk.shape[0] + last_chunk = np.pad(last_chunk, (0, pad_size)) + last_chunk = np.expand_dims(last_chunk, axis=0) + y = seg_m(last_chunk) + output.append(y) + + y = np.vstack(output) + # y: (num_chunks, num_frames, num_classes) + + mapping = get_powerset_mapping( + num_classes=seg_m.num_classes, + num_speakers=seg_m.num_speakers, + powerset_max_classes=seg_m.powerset_max_classes, + ) + labels = to_multi_label(y, mapping=mapping) + # labels: (num_chunks, num_frames, num_speakers) + + inactive = (labels.sum(axis=1) == 0).astype(np.int8) + # inactive: (num_chunks, num_speakers) + + speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m) + # speakers_per_frame: (num_frames, speakers_per_frame) + + if speakers_per_frame.max() == 0: + print("No speakers found in the audio file!") + return + + # if users specify only 1 speaker for clustering, then return the + # result directly + + # Now, get embeddings + chunk_speaker_pair, embeddings = get_embeddings( + args.speaker_embedding_model, + audio=audio, + labels=labels, + seg_m=seg_m, + # exclude_overlap=True, + exclude_overlap=False, + ) + # chunk_speaker_pair: a list of (chunk_idx, speaker_idx) + # embeddings: (batch_size, embedding_dim) + + # Please change num_clusters or threshold by yourself. + clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2) + # clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8) + clustering = sherpa_onnx.FastClustering(clustering_config) + cluster_labels = clustering(embeddings) + + chunk_speaker_to_cluster = dict() + for (chunk_idx, speaker_idx), cluster_idx in zip( + chunk_speaker_pair, cluster_labels + ): + if inactive[chunk_idx, speaker_idx] == 1: + print("skip ", chunk_idx, speaker_idx) + continue + chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx + + num_speakers = max(cluster_labels) + 1 + relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers)) + for i in range(labels.shape[0]): + for j in range(labels.shape[1]): + for k in range(labels.shape[2]): + if (i, k) not in chunk_speaker_to_cluster: + continue + t = chunk_speaker_to_cluster[(i, k)] + + if labels[i, j, k] == 1: + relabels[i, j, t] = 1 + + num_frames = ( + int( + (seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift) + / seg_m.receptive_field_shift + ) + + 1 + ) + + count = np.zeros((num_frames, relabels.shape[-1])) + for i in range(relabels.shape[0]): + this_chunk = relabels[i] + start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5) + end = start + this_chunk.shape[0] + count[start:end] += this_chunk + + if has_last_chunk: + stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift) + count = count[:stop_frame] + + sorted_count = np.argsort(-count, axis=-1) + final = np.zeros((count.shape[0], count.shape[1])) + + for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)): + for k in range(c): + final[i, sc[k]] = 1 + + min_duration_off = 0.5 + min_duration_on = 0.3 + onset = 0.5 + offset = 0.5 + # final: (num_frames, num_speakers) + + final = final.T + for kk in range(final.shape[0]): + segment_list = [] + frames = final[kk] + + is_active = frames[0] > onset + + start = None + if is_active: + start = 0 + scale = seg_m.receptive_field_shift / seg_m.sample_rate + scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5 + for i in range(1, len(frames)): + if is_active: + if frames[i] < offset: + segment = Segment( + start=start * scale + scale_offset, + end=i * scale + scale_offset, + speaker=kk, + ) + segment_list.append(segment) + is_active = False + else: + if frames[i] > onset: + start = i + is_active = True + + if is_active: + segment = Segment( + start=start * scale + scale_offset, + end=(len(frames) - 1) * scale + scale_offset, + speaker=kk, + ) + segment_list.append(segment) + + if len(segment_list) > 1: + merge_segment_list(segment_list, min_duration_off=min_duration_off) + for s in segment_list: + if s.duration < min_duration_on: + continue + print(s) + + +if __name__ == "__main__": + main() diff --git a/scripts/pyannote/segmentation/speaker-diarization-torch.py b/scripts/pyannote/segmentation/speaker-diarization-torch.py new file mode 100755 index 000000000..18a50ec08 --- /dev/null +++ b/scripts/pyannote/segmentation/speaker-diarization-torch.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +""" +Please refer to +https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml +for usages. +""" + +""" +1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main +wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx + +2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py + +``` + # self._embedding = PretrainedSpeakerEmbedding( + # self.embedding, use_auth_token=use_auth_token + # ) + self._embedding = embedding +``` +""" + +import argparse +from pathlib import Path + +import torch +from pyannote.audio import Model +from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline +from pyannote.audio.pipelines.speaker_verification import ( + ONNXWeSpeakerPretrainedSpeakerEmbedding, +) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--wav", type=str, required=True, help="Path to test.wav") + + return parser.parse_args() + + +def build_pipeline(): + embedding_filename = "./speaker-embedding.onnx" + if Path(embedding_filename).is_file(): + # You need to modify line 166 + # of pyannote/audio/pipelines/speaker_diarization.py + # Please see the comments at the start of this script for details + embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename) + else: + embedding = "hbredin/wespeaker-voxceleb-resnet34-LM" + + pt_filename = "./pytorch_model.bin" + segmentation = Model.from_pretrained(pt_filename) + segmentation.eval() + + pipeline = SpeakerDiarizationPipeline( + segmentation=segmentation, + embedding=embedding, + embedding_exclude_overlap=True, + ) + + params = { + "clustering": { + "method": "centroid", + "min_cluster_size": 12, + "threshold": 0.7045654963945799, + }, + "segmentation": {"min_duration_off": 0.5}, + } + + pipeline.instantiate(params) + return pipeline + + +@torch.no_grad() +def main(): + args = get_args() + assert Path(args.wav).is_file(), args.wav + pipeline = build_pipeline() + print(pipeline) + t = pipeline(args.wav) + print(type(t)) + print(t) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/fast-clustering.cc b/sherpa-onnx/csrc/fast-clustering.cc index f479a707e..f6ac56a36 100644 --- a/sherpa-onnx/csrc/fast-clustering.cc +++ b/sherpa-onnx/csrc/fast-clustering.cc @@ -52,7 +52,7 @@ class FastClustering::Impl { std::vector height(num_rows - 1); fastclustercpp::hclust_fast(num_rows, distance.data(), - fastclustercpp::HCLUST_METHOD_SINGLE, + fastclustercpp::HCLUST_METHOD_COMPLETE, merge.data(), height.data()); std::vector labels(num_rows);