「潤羽るしあ」のツイートをrinna/japanese-gpt-neox-3.6bでfine-tuningして、ツイートの続きを考えてもらう【LLM】

初めに

松xRさんのツイートを使ったfine turningのツイートを見て、キャラクターの再現ができると思い今はなき「潤羽るしあ」のツイートを学習して、ツイートを作れるようにしていきます。

またこちらの記事は、以下の記事を参考にして学習を行っています

note.com

環境

fine turning

データの準備

学習するためには、データが必要なため、ツイートデータを集めて以下のようなフォーマットに変換します。

    {
        "input": "エゴサしてると、",
        "completion": "お誕生日ライブのことやオリ曲のこと、お歌やダンスのことたくさん呟いてくれてみんなありがとう( ´•̥ ·̭ •̥` )♡うれしすぎる。。。"
    },
    {
        "input": "主様、",
        "completion": "おきるにゃ〜🐾 https://t.co/uX28fv6bqT"
    },

この時に基本参考記事を同じことをしていますが、メンション(ユーザー名)を除くために以下の関数を使って使って整形処理を行っています。

# '@hoge' 形式のユーザーメンションを削除する関数
def remove_mentions(text):
    return re.sub(r'@\w+\s*', '', text)

以下がツイートデータをfine tuneingで使用するための整形コードです

import json
import re

# JSONファイルを読み込む
with open('correct_tweet_archive.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# 出力を保存するリストを作成
output = []

# '@hoge' 形式のユーザーメンションを削除する関数
def remove_mentions(text):
    return re.sub(r'@\w+\s*', '', text)

tweet_count = 0

# 各ツイートについて処理を行う
for tweet in data:
    full_text = tweet['data']['text']

    # 最初の「、」または「。」で区切る
    match = re.search(r'(.*?[、。])(.*)', full_text)

    # 「、」または「。」が見つかった場合
    if match:
        first_sentence = remove_mentions(match.group(1))
        remaining_text = remove_mentions(match.group(2))

        # 新しいフォーマットに整形
        formatted = {
            "input": first_sentence,
            "completion": remaining_text
        }

        # 結果をリストに追加
        tweet_count += 1
        print(formatted)
        output.append(formatted)

print(f"{tweet_count} 件のツイートを処理しました。")

# 結果を新しいJSONファイルに書き込む
with open('formatted_rushia_tweets.json', 'w', encoding='utf-8') as f:
    json.dump(output, f, indent=4, ensure_ascii=False)

print("新しいJSONファイルに書き込みました。")

学習

学習自体は、記事の通りなのですがコードだけ記載しておきます。

必要なライブラリのインストール*

!pip install -Uqq git+https://github.com/huggingface/peft.git
!pip install -Uqq transformers datasets accelerate
!pip install sentencepiece
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install scipy

学習コード

# 基本パラメータ

model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
peft_name = "lorappo-rinna-3.6b"
output_dir = "lorappo-rinna-3.6b-rushia-results1"

from transformers import AutoTokenizer

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

CUTOFF_LEN = 512  # コンテキスト長

# トークナイズ関数
def tokenize(prompt, tokenizer):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN,
        padding=False,
    )
    return {
        "input_ids": result["input_ids"],
        "attention_mask": result["attention_mask"],
    }



# データセットをJSONからロード
import json

with open("formatted_rushia_tweets.json", "r", encoding='utf-8') as f:
    loaded_data = json.load(f)

print("データ数:", len(data))


# プロンプトテンプレートの準備
def generate_prompt(data_point):
    result = f"""### 指示:
{data_point["input"]}

### 回答:
{data_point["completion"]}
"""
    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result


# データセットの準備
VAL_SET_SIZE = 1000

train_dataset = []
val_dataset = []

for i in range(len(data)):
    if i % 5 == 0:
        x = tokenize(generate_prompt(data[i]), tokenizer)
        val_dataset.append(x)
    else:
        x = tokenize(generate_prompt(data[i]), tokenizer)
        train_dataset.append(x)

from transformers import AutoModelForCausalLM

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

# LoRAのパラメータ
lora_config = LoraConfig(
    r= 8,
    lora_alpha=16,
    target_modules=["query_key_value"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# モデルの前処理
model = prepare_model_for_int8_training(model)

# LoRAモデルの準備
model = get_peft_model(model, lora_config)

# 学習可能パラメータの確認
model.print_trainable_parameters()

import transformers
eval_steps = 1000
save_steps = 100
logging_steps = 100

# トレーナーの準備
trainer = transformers.Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    args=transformers.TrainingArguments(
        num_train_epochs=7,
        learning_rate=3e-4,
        logging_steps=logging_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=eval_steps,
        save_steps=save_steps,
        output_dir=output_dir,
        save_total_limit=3,
        push_to_hub=False,
        auto_find_batch_size=True
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# 学習の実行
model.config.use_cache = False
trainer.train()
model.config.use_cache = True

# LoRAモデルの保存
trainer.model.save_pretrained(peft_name)

推論コード

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo"
peft_name = "lorappo-rinna-3.6b"
output_dir = "lorappo-rinna-3.6b-rushia-results"

# モデルの準備
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_8bit=True,
    device_map="auto",
)

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# LoRAモデルの準備
model = PeftModel.from_pretrained(
    model,
    peft_name,
    # device_map="auto"
)

# 評価モード
model.eval()


# プロンプトテンプレートの準備
def generate_prompt(data_point):
    if data_point["input"]:
        result = f"""### 指示:
{data_point["instruction"]}

### 入力:
{data_point["input"]}

### 回答:
"""
    else:
        result = f"""### 指示:
{data_point["instruction"]}

### 回答:
"""

    # 改行→<NL>
    result = result.replace('\n', '<NL>')
    return result


# テキスト生成関数の定義
def generate(instruction, input=None, maxTokens=256) -> str:
    # 推論
    prompt = generate_prompt({'instruction': instruction, 'input': input})
    input_ids = tokenizer(prompt,
                          return_tensors="pt",
                          truncation=True,
                          add_special_tokens=False).input_ids.cuda()
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()
    # print(tokenizer.decode(outputs))

    # EOSトークンにヒットしたらデコード完了
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])

        # レスポンス内容のみ抽出
        sentinel = "### 回答:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc + len(sentinel):]
            return result.replace("<NL>", "")  # <NL>→改行
        else:
            return 'Warning: Expected prompt template to be emitted.  Ignoring output.'
    else:
        return 'Warning: no <eos> detected ignoring output'


# テキスト生成
print("自然言語処理ってさ、{0}".format(generate('自然言語処理ってさ、')))
print("台風近づいてるなぁ。{0}".format(generate('台風近づいてるなぁ。')))

学習したモデルでツイートの続きを考えてもらうと以下のようになりました