初めに
「一貫したキャラ性を持った回答をするAIを作りたい」
— 松xR (@matsu_vr) 2023年6月8日
「でもライセンスの問題もなくキャラ性を保ったままそれなりの規模があるデータセットなんて無い」
「自分のツイートを使えばいいのでは💡」
というわけで、rinna 3.6bを自分の過去ツイートでfinetuneしたオレッターです。かなり俺っぽい!! pic.twitter.com/3k7CpM5HmL
松xRさんのツイートを使ったfine turningのツイートを見て、キャラクターの再現ができると思い今はなき「潤羽るしあ」のツイートを学習して、ツイートを作れるようにしていきます。
またこちらの記事は、以下の記事を参考にして学習を行っています
環境
fine turning
データの準備
学習するためには、データが必要なため、ツイートデータを集めて以下のようなフォーマットに変換します。
{ "input": "エゴサしてると、", "completion": "お誕生日ライブのことやオリ曲のこと、お歌やダンスのことたくさん呟いてくれてみんなありがとう( ´•̥ ·̭ •̥` )♡うれしすぎる。。。" }, { "input": "主様、", "completion": "おきるにゃ〜🐾 https://t.co/uX28fv6bqT" },
この時に基本参考記事を同じことをしていますが、メンション(ユーザー名)を除くために以下の関数を使って使って整形処理を行っています。
# '@hoge' 形式のユーザーメンションを削除する関数 def remove_mentions(text): return re.sub(r'@\w+\s*', '', text)
以下がツイートデータをfine tuneingで使用するための整形コードです
import json import re # JSONファイルを読み込む with open('correct_tweet_archive.json', 'r', encoding='utf-8') as f: data = json.load(f) # 出力を保存するリストを作成 output = [] # '@hoge' 形式のユーザーメンションを削除する関数 def remove_mentions(text): return re.sub(r'@\w+\s*', '', text) tweet_count = 0 # 各ツイートについて処理を行う for tweet in data: full_text = tweet['data']['text'] # 最初の「、」または「。」で区切る match = re.search(r'(.*?[、。])(.*)', full_text) # 「、」または「。」が見つかった場合 if match: first_sentence = remove_mentions(match.group(1)) remaining_text = remove_mentions(match.group(2)) # 新しいフォーマットに整形 formatted = { "input": first_sentence, "completion": remaining_text } # 結果をリストに追加 tweet_count += 1 print(formatted) output.append(formatted) print(f"{tweet_count} 件のツイートを処理しました。") # 結果を新しいJSONファイルに書き込む with open('formatted_rushia_tweets.json', 'w', encoding='utf-8') as f: json.dump(output, f, indent=4, ensure_ascii=False) print("新しいJSONファイルに書き込みました。")
学習
学習自体は、記事の通りなのですがコードだけ記載しておきます。
必要なライブラリのインストール*
!pip install -Uqq git+https://github.com/huggingface/peft.git !pip install -Uqq transformers datasets accelerate !pip install sentencepiece !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 !pip install scipy
学習コード
# 基本パラメータ model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo" peft_name = "lorappo-rinna-3.6b" output_dir = "lorappo-rinna-3.6b-rushia-results1" from transformers import AutoTokenizer # トークナイザーの準備 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) CUTOFF_LEN = 512 # コンテキスト長 # トークナイズ関数 def tokenize(prompt, tokenizer): result = tokenizer( prompt, truncation=True, max_length=CUTOFF_LEN, padding=False, ) return { "input_ids": result["input_ids"], "attention_mask": result["attention_mask"], } # データセットをJSONからロード import json with open("formatted_rushia_tweets.json", "r", encoding='utf-8') as f: loaded_data = json.load(f) print("データ数:", len(data)) # プロンプトテンプレートの準備 def generate_prompt(data_point): result = f"""### 指示: {data_point["input"]} ### 回答: {data_point["completion"]} """ # 改行→<NL> result = result.replace('\n', '<NL>') return result # データセットの準備 VAL_SET_SIZE = 1000 train_dataset = [] val_dataset = [] for i in range(len(data)): if i % 5 == 0: x = tokenize(generate_prompt(data[i]), tokenizer) val_dataset.append(x) else: x = tokenize(generate_prompt(data[i]), tokenizer) train_dataset.append(x) from transformers import AutoModelForCausalLM # モデルの準備 model = AutoModelForCausalLM.from_pretrained( model_name, load_in_8bit=True, device_map="auto", ) from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType # LoRAのパラメータ lora_config = LoraConfig( r= 8, lora_alpha=16, target_modules=["query_key_value"], lora_dropout=0.05, bias="none", task_type=TaskType.CAUSAL_LM ) # モデルの前処理 model = prepare_model_for_int8_training(model) # LoRAモデルの準備 model = get_peft_model(model, lora_config) # 学習可能パラメータの確認 model.print_trainable_parameters() import transformers eval_steps = 1000 save_steps = 100 logging_steps = 100 # トレーナーの準備 trainer = transformers.Trainer( model=model, train_dataset=train_dataset, eval_dataset=val_dataset, args=transformers.TrainingArguments( num_train_epochs=7, learning_rate=3e-4, logging_steps=logging_steps, evaluation_strategy="steps", save_strategy="steps", eval_steps=eval_steps, save_steps=save_steps, output_dir=output_dir, save_total_limit=3, push_to_hub=False, auto_find_batch_size=True ), data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False), ) # 学習の実行 model.config.use_cache = False trainer.train() model.config.use_cache = True # LoRAモデルの保存 trainer.model.save_pretrained(peft_name)
推論コード
import torch from peft import PeftModel, PeftConfig from transformers import AutoModelForCausalLM, AutoTokenizer model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo" peft_name = "lorappo-rinna-3.6b" output_dir = "lorappo-rinna-3.6b-rushia-results" # モデルの準備 model = AutoModelForCausalLM.from_pretrained( model_name, load_in_8bit=True, device_map="auto", ) # トークナイザーの準備 tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # LoRAモデルの準備 model = PeftModel.from_pretrained( model, peft_name, # device_map="auto" ) # 評価モード model.eval() # プロンプトテンプレートの準備 def generate_prompt(data_point): if data_point["input"]: result = f"""### 指示: {data_point["instruction"]} ### 入力: {data_point["input"]} ### 回答: """ else: result = f"""### 指示: {data_point["instruction"]} ### 回答: """ # 改行→<NL> result = result.replace('\n', '<NL>') return result # テキスト生成関数の定義 def generate(instruction, input=None, maxTokens=256) -> str: # 推論 prompt = generate_prompt({'instruction': instruction, 'input': input}) input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.cuda() outputs = model.generate( input_ids=input_ids, max_new_tokens=maxTokens, do_sample=True, temperature=0.7, top_p=0.75, top_k=40, no_repeat_ngram_size=2, ) outputs = outputs[0].tolist() # print(tokenizer.decode(outputs)) # EOSトークンにヒットしたらデコード完了 if tokenizer.eos_token_id in outputs: eos_index = outputs.index(tokenizer.eos_token_id) decoded = tokenizer.decode(outputs[:eos_index]) # レスポンス内容のみ抽出 sentinel = "### 回答:" sentinelLoc = decoded.find(sentinel) if sentinelLoc >= 0: result = decoded[sentinelLoc + len(sentinel):] return result.replace("<NL>", "") # <NL>→改行 else: return 'Warning: Expected prompt template to be emitted. Ignoring output.' else: return 'Warning: no <eos> detected ignoring output' # テキスト生成 print("自然言語処理ってさ、{0}".format(generate('自然言語処理ってさ、'))) print("台風近づいてるなぁ。{0}".format(generate('台風近づいてるなぁ。')))
学習したモデルでツイートの続きを考えてもらうと以下のようになりました