Style-Bert-VITS2(SBV2)でAssertionError: choose a window size 400 that is [2, 251]の対応方法

初めに

以下の学習にて以下のエラーにて前処理が止まることがあります。こちらの対応をしていきます

packages/torchaudio/compliance/kaldi.py", line 142, in _get_waveform_and_window_properties
    assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
AssertionError: choose a window size 400 that is [2, 251]

11-15 11:16:09 | ERROR  | train.py:246 | Step 5: style_gen failed.

github.com

開発環境

  • Python 3.10.12
  • Ubunts 22.02
  • cuda 12.1

ライブラリ ver

numpy                    1.26.4
pyannote-audio           3.3.2
pyannote-core            5.0.0
pyannote-database        5.1.0
pyannote-metrics         3.2.1
pyannote-pipeline        3.0.1
pycparser                2.22
pyopenjtalk-dict         0.3.4.dev2
pytorch-lightning        2.4.0
pytorch-metric-learning  2.7.0
torch                    2.3.0+cu121
torch-audiomentations    0.11.1
torch-pitch-shift        1.2.5
torchaudio               2.3.0+cu121
torchmetrics             1.5.1
transformers             4.39.3
triton                   2.3.0

原因

原因としてはpyannote.audioライブラリ内での音声波形の長さと、信号処理で使用されるウィンドウサイズとの不一致から生じています

  • window_sizeが400サンプルであるのに対し、波形の長さが251サンプルしかないことを示しています
  • kaldi.py内のアサーションは、window_sizeが2サンプル以上で、かつ波形の長さ以下である必要があるとチェックしています

対応方法

以下の関数を save_style_vectorの初めに 追加して、波形の長さが一定以下であれば処理をしないようにします

import torchaudio

def is_audio_length_valid(wav_path: str, frame_length_ms: float = 25.0) -> bool:
    waveform, sample_rate = torchaudio.load(wav_path)
    min_length_samples = int(sample_rate * frame_length_ms / 1000.0)
    if waveform.shape[1] < min_length_samples:
        logger.warning(f"Skipping {wav_path}: audio too short ({waveform.shape[1]} samples)")
        return False
    return True
def save_style_vector(wav_path: str):
    if not is_audio_length_valid(wav_path):
        raise ValueError(f"Audio too short: {wav_path}")
    try:
        style_vec = get_style_vector(wav_path)
    except Exception as e:

またエラーログの解析用に 以下を変更しておくと便利です

def process_line(line: str):
    wav_path = line.split("|")[0]
    try:
        save_style_vector(wav_path)
        return line, None
    except NaNValueError:
        return line, "nan_error"
    except ValueError as e:
        if "Audio too short" in str(e):
            return line, "short_audio"
        else:
            logger.error(f"Unexpected error for {wav_path}: {e}")
            return line, "other_error"
    except Exception as e:
        logger.error(f"Unexpected error for {wav_path}: {e}")
        return line, "other_error"