VITS2モデルの構造を、モデルファイルとconfig.jsonを読み込んで確認する

開発環境

モデルの構造確認

以下のスクリプト(analyze_vits2_model_structure.py)で、モデルの構造を確認できます。

import torch
from pathlib import Path
import json
from safetensors import safe_open
import argparse


def load_model(model_path):
    """Load a checkpoint file from disk onto the CPU.

    Args:
        model_path: Location of a ``.safetensors`` or ``.pth`` file, given as
            a ``str`` or ``pathlib.Path``. The extension is matched
            case-insensitively.

    Returns:
        For ``.safetensors`` files, a dict mapping every key stored in the
        file to its tensor. For ``.pth`` files, whatever object
        ``torch.load`` returns.

    Raises:
        ValueError: If the extension is neither ``.safetensors`` nor ``.pth``.
    """
    # Accept plain strings as well as Path objects.
    model_path = Path(model_path)
    suffix = model_path.suffix.lower()

    if suffix == '.safetensors':
        with safe_open(model_path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    if suffix == '.pth':
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        return torch.load(model_path, map_location='cpu')

    raise ValueError(f"Unsupported file format: {model_path.suffix}")


def analyze_model_structure(model_dict, target_sizes=(256, 512)):
    """Collect the last-dimension sizes of the tensors in a state dict.

    Args:
        model_dict: Mapping of parameter name to value. Values that are not
            ``torch.Tensor`` instances, and tensors with no dimensions, are
            ignored.
        target_sizes: Last-dimension sizes to single out. Defaults to
            ``(256, 512)``.

    Returns:
        A ``(sizes, important_shapes)`` tuple. ``sizes`` is the set of
        distinct last-dimension sizes seen. ``important_shapes`` maps the name
        of each tensor whose last dimension is in ``target_sizes`` to its full
        shape.
    """
    highlighted = frozenset(target_sizes)
    sizes = set()
    important_shapes = {}

    for name, param in model_dict.items():
        # Non-tensors and 0-dim tensors have no last dimension to inspect.
        if not isinstance(param, torch.Tensor) or len(param.shape) == 0:
            continue

        last_dim = param.shape[-1]
        sizes.add(last_dim)
        if last_dim in highlighted:
            important_shapes[name] = param.shape

    return sizes, important_shapes


def analyze_models(model_paths, config_path):
    """Print a structural report for model files plus key config values.

    For each model file, prints the distinct last-dimension sizes and the
    parameters singled out by ``analyze_model_structure``. Afterwards prints
    the sizes found across all models and selected fields of the config.

    Args:
        model_paths: Iterable of paths to model files readable by
            ``load_model``.
        config_path: Path to the JSON config file.

    Returns:
        None. The report is written to standard output.
    """
    config_path = Path(config_path)

    # Read as UTF-8 explicitly: the platform default encoding (e.g. on
    # Windows) may fail to decode non-ASCII text in the config.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    all_sizes = set()

    for model_path in model_paths:
        model_path = Path(model_path)
        print(f"\nAnalyzing {model_path.name}:")

        model_dict = load_model(model_path)
        # If the loaded object has a 'model' entry, analyze that entry
        # instead of the top-level dict.
        if 'model' in model_dict:
            model_dict = model_dict['model']

        sizes, important_shapes = analyze_model_structure(model_dict)
        all_sizes.update(sizes)

        print(f"Unique sizes found: {sorted(sizes)}")
        print("Important shapes (256 or 512):")
        for name, shape in important_shapes.items():
            print(f"  {name}: shape = {shape}")

    print("\nOverall summary:")
    print(f"All unique sizes found across models: {sorted(all_sizes)}")

    # A missing or null 'model' section falls back to "Not specified",
    # like the other fields, instead of raising KeyError.
    model_config = config.get('model') or {}

    print("\nImportant config information:")
    print(f"Model name: {config.get('model_name', 'Not specified')}")
    print(f"Version: {config.get('version', 'Not specified')}")
    print(f"Gin channels: {model_config.get('gin_channels', 'Not specified')}")
    print(f"Hidden channels: {model_config.get('hidden_channels', 'Not specified')}")


if __name__ == "__main__":
    # Command-line entry point: the config file comes first, followed by one
    # or more model files.
    cli = argparse.ArgumentParser(description="Analyze VITS2 model files")
    cli.add_argument(
        "config_path",
        type=str,
        help="Path to the config.json file",
    )
    cli.add_argument(
        "model_paths",
        type=str,
        nargs='+',
        help="Paths to the model files (.pth or .safetensors)",
    )
    parsed = cli.parse_args()

    analyze_models(parsed.model_paths, parsed.config_path)

config.jsonのパスに続けてモデルファイル(.pth または .safetensors)のパスを指定して実行すると、以下のようにモデルの構造を確認できます。

python .\analyze_vits2_model_structure.py .\model_assets\test\config.json  .\model_assets\test\test_e1000_s25000.safetensors

Analyzing test_e1000_s25000.safetensors:
Unique sizes found: [1, 2, 3, 5, 7, 8, 11, 16, 29, 32, 64, 96, 128, 192, 256, 384, 512, 768]
Important shapes (256 or 512):
  dec.cond.bias: shape = torch.Size([512])
  dec.conv_pre.bias: shape = torch.Size([512])
  dec.resblocks.0.convs1.0.bias: shape = torch.Size([256])
  dec.resblocks.0.convs1.1.bias: shape = torch.Size([256])
  dec.resblocks.0.convs1.2.bias: shape = torch.Size([256])
  dec.resblocks.0.convs2.0.bias: shape = torch.Size([256])
  dec.resblocks.0.convs2.1.bias: shape = torch.Size([256])
  dec.resblocks.0.convs2.2.bias: shape = torch.Size([256])
  dec.resblocks.1.convs1.0.bias: shape = torch.Size([256])
  dec.resblocks.1.convs1.1.bias: shape = torch.Size([256])
  dec.resblocks.1.convs1.2.bias: shape = torch.Size([256])
  dec.resblocks.1.convs2.0.bias: shape = torch.Size([256])
  dec.resblocks.1.convs2.1.bias: shape = torch.Size([256])
  dec.resblocks.1.convs2.2.bias: shape = torch.Size([256])
  dec.resblocks.2.convs1.0.bias: shape = torch.Size([256])
  dec.resblocks.2.convs1.1.bias: shape = torch.Size([256])
  dec.resblocks.2.convs1.2.bias: shape = torch.Size([256])
  dec.resblocks.2.convs2.0.bias: shape = torch.Size([256])
  dec.resblocks.2.convs2.1.bias: shape = torch.Size([256])
  dec.resblocks.2.convs2.2.bias: shape = torch.Size([256])
  dec.ups.0.bias: shape = torch.Size([256])
  dp.conv_1.bias: shape = torch.Size([256])
  dp.conv_2.bias: shape = torch.Size([256])
  dp.norm_1.beta: shape = torch.Size([256])
  dp.norm_1.gamma: shape = torch.Size([256])
  dp.norm_2.beta: shape = torch.Size([256])
  dp.norm_2.gamma: shape = torch.Size([256])
  emb_g.weight: shape = torch.Size([1, 512])
  enc_p.encoder.spk_emb_linear.weight: shape = torch.Size([192, 512])
  enc_p.style_proj.weight: shape = torch.Size([192, 256])
  flow.flows.0.enc.spk_emb_linear.weight: shape = torch.Size([192, 512])
  flow.flows.2.enc.spk_emb_linear.weight: shape = torch.Size([192, 512])
  flow.flows.4.enc.spk_emb_linear.weight: shape = torch.Size([192, 512])
  flow.flows.6.enc.spk_emb_linear.weight: shape = torch.Size([192, 512])

Overall summary:
All unique sizes found across models: [1, 2, 3, 5, 7, 8, 11, 16, 29, 32, 64, 96, 128, 192, 256, 384, 512, 768]

Important config information:
Model name: nadeko
Version: 2.4.1-JP-Extra
Gin channels: 512
Hidden channels: 192