Google Colobでisek-ai/SDPrompt-RetNet-300Mを動かす

はじめに

こちらのポストを見て、単語からプロンプトを生成していてすごい!?となったので実際に動かしてみます

ライブラリのインストール

!pip install transformers safetensors timm

実装

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

MODEL_NAME = "isek-ai/SDPrompt-RetNet-300M"

DEVICE = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
).to(DEVICE)

streamer = TextStreamer(tokenizer)

prompt = "<s>1girl"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

_ = model.generate(
    inputs["input_ids"],
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    top_k=20,
    temperature=0.9,
    streamer=streamer,
)

出力が以下のようになりました

<s> 1girl, blue hair, boots, commentary request, computer, cyberpunk, hatsune miku, highres, huke, long hair, monitor, multiple monitors, skirt, solo, thigh boots, thighhighs, very long hair, vocaloid, zettai ryouiki</s>

別のパターンでも試してみます

prompt = "<s>dog"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

_ = model.generate(
    inputs["input_ids"],
    max_new_tokens=256,
    do_sample=True,
    top_p=0.9,
    top_k=20,
    temperature=0.9,
    streamer=streamer,
)
<s> dog jumps over hill, dog looks like elephant with trunk!!!!, intricate, epic lighting, cinematic composition, hyper realistic, 8 k resolution, unreal engine 5, by artgerm, tooth wu, dan mumford, beeple, wlop, rossdraws, james jean, marc simonetti, artstation</s

所感

自然言語から単語を抽出してプロンプトにできれば良さそうですが、プロンプトの内容が単語レベルから推測されるものがざっくりしているので、実際には少し調節が必要な感じです