wespeakerとxvectorの話者埋め込みモデルを使った日本語話者ダイアライゼーションの評価

初めに

音声データを文字お越しをする際に、複数人の音声が入っている場合に 「誰がいつ話したのか」を推定する技術として 話者ダイアライゼーションがあります。今回は、日本語音声において いくつかのモデルを使って比較をしていきます

過去に各モデルを動かす記事は書いていますので、参考に見てください

ayousanz.hatenadiary.jp

ayousanz.hatenadiary.jp

ayousanz.hatenadiary.jp

ayousanz.hatenadiary.jp

事前調査

話者ダイアライゼーションでよく使われるものは、pyannote-audioがあります。こちらは少し前に pyannote/speaker-diarization-3.1がリリースされています。この埋め込みモデルには wespeaker-voxceleb-resnet34-LMが使われているため、元になっている wespeakerのオリジナルモデルで評価を行っています

開発環境

評価データ

今回の評価データには CABank Japanese CallHome Corpusを使用しています。

このデータセットは huggingfaceのtalkbank/callhomeからダウンロードすることができます。

今回評価をするために、元のデータセットからwavデータに変換してメタデータjson化したデータセットを以下に公開しています

huggingface.co

対象のモデル

今回は以下のモデルを対象に評価を行いました

モデル名 URL
wespeaker-cnceleb-resnet34-LM ダウンロードリンク
wespeaker-voxceleb-resnet152-LM ダウンロードリンク
wespeaker-voxceleb-resnet293-LM ダウンロードリンク
wespeaker-voxceleb-resnet34-LM ダウンロードリンク
xvector_jtubespeech ダウンロードリンク

評価結果

それぞれのモデルの評価結果は以下になります。詳細はリンク先の評価データをご確認ください。

モデル名 平均DER 評価結果のリンク先
wespeaker-cnceleb-resnet34-LM 46.99% wespeaker-cnceleb-resnet34-LM-result.txt
wespeaker-voxceleb-resnet152-LM 38.72% wespeaker-voxceleb-resnet152-LM_results.txt
wespeaker-voxceleb-resnet293-LM 39.02% wespeaker-voxceleb-resnet293-LM_results.txt
wespeaker-voxceleb-resnet34-LM 39.27% wespeaker-voxceleb-resnet34-LM-result.txt
xvector_jtubespeech 48.52% xvector_jtubespeech-der-umap_results.txt

評価方法

それぞれのモデルの評価方法およびそのコードは以下になります。

フォルダ構成は以下にようになっています。

- **root/**
  - **callhome_japanese_audio/**
    - callhome_jpn_0.wav
    - callhome_jpn_1.wav
    - callhome_jpn_2.wav
    - ...
    - callhome_jpn_99.wav
    - wav.scp
  - **wespeaker-cnceleb-resnet34-LM/**
  - **wespeaker-voxceleb-resnet152-LM/**
  - **wespeaker-voxceleb-resnet293-LM/**
  - **wespeaker-voxceleb-resnet34-LM/**
  - callhome_japanese_metadata.json
  - callhome_japanese.rttm
  - predicted.rttm
  - README.md
  - requirements.txt
  - umap_clusterer.py
  - wespeacker-test.py
  - wespeaker-cnceleb-resnet34-LM-result.txt
  - wespeaker-voxceleb-resnet152-LM_results.txt
  - wespeaker-voxceleb-resnet293-LM_results.txt
  - wespeaker-voxceleb-resnet34-LM-result.txt
  - x-vector-umap-test.py
  - xvector_jtubespeech-der-umap_results.txt

以下に実際に評価を行ったRepositoryを公開しています(詳細の結果もこちらにあげています)

github.com

wespeaker

評価の流れとしては以下になります

  1. WeSpeakerモデルをロードし、デバイス(CPU/GPU)を設定する。
  2. 音声ファイルのリストを作成して準備する。
  3. 各音声ファイルに対してmodel.diarizeで話者ダイアライゼーションを実行する。
  4. ダイアライゼーション結果をRTTM形式のファイルに保存する。
  5. リファレンスのRTTMファイルを用意する(必要ならメタデータから作成)。
  6. リファレンスと予測結果のRTTMファイルを読み込み、評価用アノテーションを作成する。
  7. DiarizationErrorRateの評価指標を初期化する。
  8. 各音声ファイルについてDERを計算し、結果を記録する。
  9. 全体のDERを計算し、評価結果をファイルに保存する。

評価コードは以下になります

import os
import json
import torch
import wespeaker
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Annotation, Segment

# モデルのパスを指定
model_dir = "wespeaker-voxceleb-resnet152-LM"
model = wespeaker.load_model_local(model_dir)

# 必要に応じてデバイスを設定
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.set_device(device)

print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"CUDA Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available. Using CPU.")

# 音声ファイルのディレクトリ
audio_dir = 'callhome_japanese_audio'

# 音声ファイルのリストを作成
audio_files = [
    os.path.join(audio_dir, f)
    for f in os.listdir(audio_dir)
    if f.endswith('.wav') or f.endswith('.mp3')
]

# 結果を格納するリスト
all_results = []

for audio_file in audio_files:
    utt_id = os.path.splitext(os.path.basename(audio_file))[0]
    print(f"Processing {utt_id}...")
    diarization_result = model.diarize(audio_file)
    all_results.append((utt_id, diarization_result))

# 予測結果のRTTMファイルの作成
hypothesis_rttm = 'predicted.rttm'
with open(hypothesis_rttm, 'w', encoding='utf-8') as f:
    for utt_id, result in all_results:
        for segment in result:
            # segment: [開始時間, 終了時間, 話者ラベル]
            start_time = float(segment[1])
            end_time = float(segment[2])
            speaker_label = segment[3]
            duration = end_time - start_time
            f.write(f"SPEAKER {utt_id} 1 {start_time:.3f} {duration:.3f} <NA> <NA> {speaker_label} <NA> <NA>\n")

# リファレンスのRTTMファイルを作成(既に作成済みの場合はこの部分をコメントアウト可能)
json_input_path = 'callhome_japanese_metadata.json'
reference_rttm = 'callhome_japanese.rttm'
with open(json_input_path, 'r', encoding='utf-8') as f:
    metadata_list = json.load(f)
with open(reference_rttm, 'w', encoding='utf-8') as f_rttm:
    for metadata in metadata_list:
        audio_filename = metadata['audio_filename']
        uri = os.path.splitext(audio_filename)[0]
        utterances = metadata['utterances']
        for utt in utterances:
            start_time = utt['start_time']
            end_time = utt['end_time']
            duration = end_time - start_time
            speaker = utt['speaker']
            f_rttm.write(f"SPEAKER {uri} 1 {start_time:.3f} {duration:.3f} <NA> <NA> {speaker} <NA> <NA>\n")

# リファレンスと予測結果のRTTMファイルを読み込み
def load_rttm(file_path):
    annotations = {}
    with open(file_path, 'r') as f:
        for line in f:
            tokens = line.strip().split()
            uri = tokens[1]
            start_time = float(tokens[3])
            duration = float(tokens[4])
            end_time = start_time + duration
            speaker = tokens[7]
            segment = Segment(start_time, end_time)
            if uri not in annotations:
                annotations[uri] = Annotation(uri=uri)
            annotations[uri][segment] = speaker
    return annotations

reference = load_rttm(reference_rttm)
hypothesis = load_rttm(hypothesis_rttm)

metric = DiarizationErrorRate()

## 出力を保存するテキストファイルを開く
output_file = 'wespeaker-voxceleb-resnet152-LM_results.txt'
with open(output_file, 'w', encoding='utf-8') as result_f:

    # 各ファイルごとに評価
    for utt_id in reference:
        ref = reference[utt_id]
        hyp = hypothesis.get(utt_id, None)

        if hyp is None:
            result_line = f"Hypothesis for {utt_id} not found."
            
            print(result_line)
            result_f.write(result_line + '\n')
            continue

        der = metric(ref, hyp)
        result_line = f"{utt_id}: DER = {der * 100:.2f}%"
        print(result_line)
        result_f.write(result_line + '\n')

    # 全体のDERを計算
    total_der = abs(metric)
    total_result_line = f"Total DER: {total_der * 100:.2f}%"
    print(total_result_line)
    result_f.write(total_result_line + '\n')

# 結果が 'der_results.txt' に保存されます
print(f"Results saved to {output_file}")

xvector_jtubespeech

wespeakerと評価方法を揃えるために、UMAPによる次元削減およびHDBSCANによるクラスタリングを行っています

  1. x-vectorモデルをロードし、デバイス(CPU/GPU)を設定する。
  2. 処理する音声ファイルのリストを準備する。
  3. Silero VADモデルをロードして音声区間検出を準備する。
  4. 各音声ファイルに対してVADで音声区間を検出する。
  5. 検出された各音声区間からMFCCとx-vector埋め込みを抽出する。
  6. 埋め込みベクトルと対応するセグメントを収集する。
  7. UMAPで埋め込みベクトルの次元削減を行う。
  8. HDBSCANで次元削減後のベクトルをクラスタリングし、話者ラベルを割り当てる。
  9. 必要に応じてPAHCでクラスタを精錬する。
  10. ダイアライゼーション結果をRTTM形式のファイルに書き出す。
  11. リファレンスのRTTMファイルを読み込み、評価用アノテーションを作成する。
  12. DiarizationErrorRate評価指標を初期化する。
  13. 各音声ファイルについてDERを計算し、結果を記録する。
  14. 全体のDERを計算し、評価結果をファイルに保存する。
  15. 結果を出力して評価を完了する。

評価コードは以下になります

# スクリプトの最初に環境変数を設定
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["NUMBA_NUM_THREADS"] = "1"

import json
import numpy as np
import torch
import torchaudio
from scipy.io import wavfile
from torchaudio.compliance import kaldi
from pyannote.metrics.diarization import DiarizationErrorRate
from pyannote.core import Annotation, Segment
# UMAPとHDBSCANをインポート
import umap
import hdbscan
# 必要に応じてPAHCクラスをインポートまたは定義
from wespeaker.diar.umap_clusterer import PAHC  # PAHCクラスを別途コピーして使用

from xvector_jtubespeech import XVector
from tqdm import tqdm

# 1. x-vectorモデルのロード
model = torch.hub.load("sarulab-speech/xvector_jtubespeech", "xvector", trust_repo=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"CUDA Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available. Using CPU.")

# 2. 音声ファイルのディレクトリ
audio_dir = 'callhome_japanese_audio'

# 音声ファイルのリストを作成
audio_files = [
    os.path.join(audio_dir, f)
    for f in os.listdir(audio_dir)
    if f.endswith('.wav') or f.endswith('.mp3')
]

# 音声区間検出(VAD)の準備
vad_model, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad',
    model='silero_vad',
    force_reload=False
)
(get_speech_timestamps, _, _, _, _) = utils

# VADモデルはCPU上にあることを確認
vad_model.eval()

# 予測結果のRTTMファイルを作成するためのリスト
all_results = []

for audio_file in audio_files:
    utt_id = os.path.splitext(os.path.basename(audio_file))[0]
    print(f"Processing {utt_id}...")

    # 2.1 音声の読み込み(wav は CPU 上)
    wav, sr = torchaudio.load(audio_file)
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        wav = resampler(wav)
        sr = 16000

    # 2.2 音声区間検出(VAD)の適用
    speech_timestamps = get_speech_timestamps(wav.squeeze(0), vad_model, sampling_rate=sr)
    if not speech_timestamps:
        print(f"No speech detected in {utt_id}.")
        continue

    # 3. セグメントごとにx-vectorを抽出
    embeddings = []
    segments = []
    for ts in speech_timestamps:
        start_frame = ts['start']
        end_frame = ts['end']
        segment_wav = wav[:, start_frame:end_frame]
        segment_wav_np = segment_wav.numpy().squeeze(0)

        # 3.1 MFCCの抽出
        segment_tensor = torch.from_numpy(segment_wav_np.astype(np.float32)).unsqueeze(0).to(device)
        mfcc = kaldi.mfcc(segment_tensor, num_ceps=24, num_mel_bins=24).unsqueeze(0)

        # 3.2 x-vectorの抽出
        with torch.no_grad():
            xvector = model.vectorize(mfcc)
        xvector = xvector.cpu().numpy()[0]

        embeddings.append(xvector)
        # 時間を秒に変換
        start_time = start_frame / sr
        end_time = end_frame / sr
        segments.append((start_time, end_time))

    embeddings = np.array(embeddings)

    # 4. UMAPによる次元削減
    if len(embeddings) <= 2:
        labels = [0] * len(embeddings)
    else:
        umap_embeddings = umap.UMAP(
            n_components=min(32, len(embeddings) - 2),
            metric='cosine',
            n_neighbors=16,  # 必要に応じて調整
            min_dist=0.05,   # 必要に応じて調整
            random_state=2023,
            n_jobs=1
        ).fit_transform(embeddings)

        # 5. HDBSCANによるクラスタリング
        labels = hdbscan.HDBSCAN(
            allow_single_cluster=True,
            min_cluster_size=4,
            approx_min_span_tree=False,
            core_dist_n_jobs=1
        ).fit_predict(umap_embeddings)

        # 6. PAHCによるクラスタのマージと吸収
        labels = PAHC(
            merge_cutoff=0.3,
            min_cluster_size=3,
            absorb_cutoff=0.0
        ).fit_predict(labels, embeddings)

    # 予測結果を保存
    diarization_result = []
    for (segment, label) in zip(segments, labels):
        diarization_result.append([utt_id, segment[0], segment[1], label])
    all_results.extend(diarization_result)

# 7. 予測結果のRTTMファイルの作成
hypothesis_rttm = 'predicted.rttm'
with open(hypothesis_rttm, 'w', encoding='utf-8') as f:
    for entry in all_results:
        utt_id, start_time, end_time, speaker_label = entry
        duration = end_time - start_time
        f.write(f"SPEAKER {utt_id} 1 {start_time:.3f} {duration:.3f} <NA> <NA> speaker_{speaker_label} <NA> <NA>\n")

# 以下、評価コード(リファレンスのRTTMファイルの読み込みなど)を追加

# 8. リファレンスのRTTMファイルを読み込み(ご自身のコードに合わせてください)
# 例えば:
reference_rttm = 'callhome_japanese.rttm'
def load_rttm(file_path):
    annotations = {}
    with open(file_path, 'r') as f:
        for line in f:
            tokens = line.strip().split()
            uri = tokens[1]
            start_time = float(tokens[3])
            duration = float(tokens[4])
            end_time = start_time + duration
            speaker = tokens[7]
            segment = Segment(start_time, end_time)
            if uri not in annotations:
                annotations[uri] = Annotation(uri=uri)
            annotations[uri][segment] = speaker
    return annotations

reference = load_rttm(reference_rttm)
hypothesis = load_rttm(hypothesis_rttm)

metric = DiarizationErrorRate()

# 9. 評価結果を保存するテキストファイルを開く
output_file = 'xvector_jtubespeech-der-umap_results.txt'
with open(output_file, 'w', encoding='utf-8') as result_f:

    # 各ファイルごとに評価
    for utt_id in reference:
        ref = reference[utt_id]
        hyp = hypothesis.get(utt_id, None)

        if hyp is None:
            result_line = f"Hypothesis for {utt_id} not found."
            print(result_line)
            result_f.write(result_line + '\n')
            continue

        der = metric(ref, hyp)
        result_line = f"{utt_id}: DER = {der * 100:.2f}%"
        print(result_line)
        result_f.write(result_line + '\n')

    # 全体のDERを計算
    total_der = abs(metric)
    total_result_line = f"Total DER: {total_der * 100:.2f}%"
    print(total_result_line)
    result_f.write(total_result_line + '\n')

print(f"Results saved to {output_file}")