複数のLLMのPerplexityの精度を比較して、文章の自然さを判定を試す

初めに

環境

  • L4 GPU
  • ubuntu22.04

準備

ライブラリをインストールします

pip install torch transformers huggingface_hub

比較対象のモデル

  • stabilityai/StableBeluga-7B
  • mistralai/Mistral-7B-Instruct-v0.2
  • Rakuten/RakutenAI-7B-chat

対象のデータ

今回の対象のデータは yodasのja000の一部を使用します。

   Text
1   それと僕が材料をお伝えした時にバニラエッセンスを入れたじゃないですか
2   1弦の5フレット、2弦の5フレット、 3弦の5フレット、2弦の7フレット、
3   けどもこれでえっと木スキル使うとさらに カウンターが1個貯まる
4   長い年月をかけて韓国人朝鮮人 と向き合ってきた中国人は韓国人
5   ごいハマり始めて
6   50話 いらっしゃいませ♪ ヘラのグランプリ!
7   オムニテクでは人々を助けを上がっている。
8   ステーキってもんはね
9   これが聞こえてきた
10  なるほどねそうなんだやばいすげ えわ
11  殺菌得ながらを
12  じゃあ呼吸法やりましょう、呼吸、呼吸
13  カンタ : そこやれるの見たいと思ってるよ、みんな
14  すげーなんか色んなものがこう そうなんだそうそうばすごいね
15  前回は standard assets を使って 遊んでいました
16  を獲得しました
17  しかしながら本当の理由はそれ だけではないのです
18  病むなんですが
19  全体混ぜると これね混ぜたら数分蒸らした方がいいです
20  アプリはそういったことを解決して くれるアプリです
21  ご丁寧に
22  下に
23  合わせてかぶせれば...
24  そういう言い方で
25  朝食
26  流行ったものだそうですが
27  こうなります
28  次は釣りレベルの上限解放です。今度は何を釣らされるんだろうと思っていましたが、タマカイでした。
29  守りたい
30  だってさっきの人从众…

複数のモデルでPerplexityの値を取得

以下のコードでそれぞれのLLMのPerplexityの値を取得する

import json
import csv
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def perplexity(model, tokenizer, text) -> torch.Tensor:
    tokenized_input = tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt"
    ).to(model.device)

    with torch.inference_mode():
        output = model(tokenized_input, labels=tokenized_input)
        ppl = torch.exp(output.loss)

    return ppl.item()

models = [
    "stabilityai/StableBeluga-7B",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "Rakuten/RakutenAI-7B-chat"
]

results = {}
csv_data = [["Text"] + [f"{model}_Perplexity" for model in models]]

with open("testData.txt", "r", encoding="utf-8") as f:
    lines = [line.strip() for line in f.readlines()[:30]]

for line in lines:
    line_results = {"text": line}
    csv_line = [line]

    for model_name in models:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cuda", torch_dtype=torch.float16
        )

        ppl = perplexity(model, tokenizer, line)

        line_results[model_name] = {
            "perplexity": ppl
        }
        csv_line.append(ppl)

        del model
        torch.cuda.empty_cache()

    results[line] = line_results

with open("perplexity_results5.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=4)

with open("perplexity_results5.csv", "w", newline="", encoding="utf-8") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(csv_data[0])
    for line in lines:
        csv_writer.writerow([line] + [results[line][model]["perplexity"] for model in models])

結果

Text stabilityai/StableBeluga-7B_Perplexity mistralai/Mistral-7B-Instruct-v0.2_Perplexity Rakuten/RakutenAI-7B-chat_Perplexity
それと僕が材料をお伝えした時にバニラエッセンスを入れたじゃないですか 13.036222457885742 30.592239379882812 77.44200134277344
1弦の5フレット、2弦の5フレット、 3弦の5フレット、2弦の7フレット、 4.010903358459473 8.111798286437988 6.279117107391357
けどもこれでえっと木スキル使うとさらに カウンターが1個貯まる 43.25908279418945 83.80355072021484 470.63653564453125
長い年月をかけて韓国人朝鮮人 と向き合ってきた中国人は韓国人 9.886605262756348 30.079328536987305 87.20085144042969
ごいハマり始めて 293.4146423339844 4317.986328125 334.7664794921875
50話 いらっしゃいませ♪ ヘラのグランプリ! 15.717564582824707 58.66378402709961 111.03834533691406
オムニテクでは人々を助けを上がっている。 65.81471252441406 90.47702026367188 1869.9520263671875
ステーキってもんはね 127.35417938232422 1428.76318359375 1153.15966796875
これが聞こえてきた 14.765314102172852 872.0487670898438 120.8353500366211
なるほどねそうなんだやばいすげ えわ 117.2963638305664 360.88482666015625 478.8385009765625
殺菌得ながらを 121.11297607421875 1618.6380615234375 3937.654052734375
じゃあ呼吸法やりましょう、呼吸、呼吸 9.24917221069336 33.77603530883789 262.35357666015625
カンタ : そこやれるの見たいと思ってるよ、みんな 51.29374694824219 86.22332000732422 313.9569091796875
すげーなんか色んなものがこう そうなんだそうそうばすごいね 37.839759826660156 103.7863540649414 135.54092407226562
前回は standard assets を使って 遊んでいました 33.08415985107422 90.7002182006836 162.7081298828125
を獲得しました 10.945358276367188 405.0987548828125 819.6907958984375
しかしながら本当の理由はそれ だけではないのです 10.608744621276855 28.476654052734375 130.30300903320312
病むなんですが 57.94710922241211 5614.7802734375 7005.65576171875
全体混ぜると これね混ぜたら数分蒸らした方がいいです 21.468355178833008 77.3314437866211 306.8896179199219
アプリはそういったことを解決して くれるアプリです 11.473381042480469 35.310543060302734 87.94023132324219
ご丁寧に 64.8609619140625 768.6679077148438 146.9370880126953
下に 8333.6044921875 385289856.0 118921.09375
合わせてかぶせれば... 79.0224838256836 315.1685485839844 1221.1976318359375
そういう言い方で 27.01752281188965 314.7711486816406 442.60284423828125
朝食 21129.865234375 302320000.0 658.495849609375
流行ったものだそうですが 33.65590286254883 241.5250701904297 1139.429443359375
こうなります 68.17607116699219 33799.25390625 943.4025268554688
次は釣りレベルの上限解放です。今度は何を釣らされるんだろうと思っていましたが、タマカイでした。 14.19534683227539 21.08135986328125 23.427536010742188
守りたい 164.52532958984375 41213.8125 217.5680389404297
だってさっきの人从众… 254.49359130859375 1318.85205078125 2268.48974609375