Running inference (with speed measurements) on Windows with MiraTTS, made fast by LMDeploy optimizations

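Introduction

In this post I try out MiraTTS, which is said to offer fast inference.

I have published a version of the repository that works on Windows with uv here:

https://github.com/ayutaz/MiraTTS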

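Development environment

| Item            | Version                              |
|-----------------|--------------------------------------|
| OS              | Windows 11                           |
| CUDA            | 12.x (v13.0 also confirmed working)  |
| Python          | 3.11                                 |
| Package manager | uv                                   |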

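Environment setup

When setting up the environment on Windows, note that nvidia-nccl-cu12 is Linux-only and is not available for Windows. This is why the pyproject.toml in this post restricts it to Linux through uv's override-dependencies.

Clone the repository: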

git clone https://github.com/ayutaz/MiraTTS.git
cd MiraTTS

pyproject.toml is defined as follows:

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "FastNeuTTS"
version = "0.0.11"
authors = [
  { name="Yatharth Sharma", email="yatharthsharma3501@gmail.com" },
]
description = "High quality and Fast TTS with MiraTTS"
readme = "README.md"
requires-python = ">=3.10"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]
dependencies = [
    "torch>=2.0.0",
    "torchaudio>=2.0.0",
    "torchvision>=0.15.0",
    "lmdeploy",
    "librosa",
    "fastaudiosr @ git+https://github.com/ysharma3501/FlashSR.git",
    "ncodec @ git+https://github.com/ysharma3501/FastBiCodec.git",
    "einops",
    "onnxruntime-gpu",
    "omegaconf>=2.3.0",
]

[project.urls]
Homepage = "https://github.com/ysharma3501/MiraTTS"
Issues = "https://github.com/ysharma3501/MiraTTS/issues"

[tool.uv]
override-dependencies = [
    "nvidia-nccl-cu12 ; sys_platform == 'linux'",
]

[tool.uv.sources]
torch = [
    { index = "pytorch-cu124", marker = "sys_platform == 'win32'" },
    { index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
]
torchaudio = [
    { index = "pytorch-cu124", marker = "sys_platform == 'win32'" },
    { index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
]
torchvision = [
    { index = "pytorch-cu124", marker = "sys_platform == 'win32'" },
    { index = "pytorch-cu124", marker = "sys_platform == 'linux'" },
]

[[tool.uv.index]]
name = "pytorch-cu124"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

Install the dependencies:

uv sync
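
After `uv sync`, it can be useful to confirm what actually ended up in the environment. The following is a minimal sketch using only the standard library. The helper names `installed_version` and `report` are made up for illustration, and the package names are taken from the pyproject.toml above. On Windows, `nvidia-nccl-cu12` would be expected to show as not installed if the override took effect.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def report(dist_names):
    """Map each distribution name to its installed version (or None)."""
    return {name: installed_version(name) for name in dist_names}


if __name__ == "__main__":
    # Names taken from the pyproject.toml in this article.
    names = ["torch", "torchaudio", "lmdeploy", "nvidia-nccl-cu12"]
    for name, ver in report(names).items():
        print(f"{name}: {ver or 'not installed'}")
```

Saved under a file name of your choice (for example a hypothetical `check_env.py`), this could be run inside the project environment with `uv run python check_env.py`.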

Inference

I ran inference on the following texts:

[
    ("Short", "Hello, this is a test of the MiraTTS text to speech system."),
    ("Medium", "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet."),
    ("Long", "Artificial intelligence is transforming the way we interact with technology, making it more natural and intuitive than ever before."),
]
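
A timing loop over texts like these could be sketched as follows. This is not the measurement script actually used for this post. `time_inference` and `fake_synthesize` are hypothetical names, and `synthesize` stands in for the real MiraTTS call, whose API is not shown here. The sketch runs untimed warm-up calls first so that one-time setup does not count toward the measurement.

```python
import time


def time_inference(synthesize, texts, warmup_runs=1):
    """Time one call of `synthesize` per text after some warm-up calls.

    `synthesize` is any callable taking a text string; it is a stand-in
    for the real TTS call, which this sketch does not assume anything about.
    """
    results = []
    for label, text in texts:
        # Warm-up calls are not timed, so one-time setup and cache
        # population are excluded from the measurement.
        for _ in range(warmup_runs):
            synthesize(text)
        start = time.perf_counter()
        synthesize(text)
        elapsed = time.perf_counter() - start
        results.append((label, len(text), elapsed))
    return results


if __name__ == "__main__":
    def fake_synthesize(text):
        # Placeholder workload so the sketch runs without a GPU or model.
        time.sleep(0.0001 * len(text))

    texts = [
        ("Short", "Hello, this is a test of the MiraTTS text to speech system."),
    ]
    for label, n_chars, seconds in time_inference(fake_synthesize, texts):
        print(f"{label} ({n_chars} chars): {seconds:.3f}s")
```

If the real call returns before the GPU work has finished, a synchronization step such as `torch.cuda.synchronize()` before reading the clock may be needed for the timing to be meaningful.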

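From the second run onward, inference is sped up by the KV cache, which gave the results below. It looks like RTF will not reach 0.1 without several further speedups such as batching.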

| Text              | Inference time |
|-------------------|----------------|
| Short (59 chars)  | 0.951s (4.6x)  |
| Medium (97 chars) | 1.306s (4.7x)  |
| Long (131 chars)  | 1.883s (4.4x)  |
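
To relate these figures to the RTF target, here is a small worked sketch. It assumes the value in parentheses is the speed relative to real time (audio duration divided by inference time), which is not stated explicitly in this post. Under that assumption, RTF is simply its reciprocal. The function name `rtf_from_speedup` is made up for illustration.

```python
def rtf_from_speedup(speedup):
    """RTF = inference time / audio duration = 1 / real-time speedup."""
    return 1.0 / speedup


# (label, inference seconds, assumed real-time speedup) from the table above.
measurements = [
    ("Short", 0.951, 4.6),
    ("Medium", 1.306, 4.7),
    ("Long", 1.883, 4.4),
]

for label, seconds, speedup in measurements:
    rtf = rtf_from_speedup(speedup)
    # Implied audio length, valid only under the assumption stated above.
    audio_seconds = seconds * speedup
    print(f"{label}: RTF ~ {rtf:.2f}, implied audio ~ {audio_seconds:.1f}s")
```

Read this way, the measurements correspond to an RTF of roughly 0.21 to 0.23, so reaching 0.1 would need a bit more than a 2x further speedup.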