Qwen2-Audio-7B-Instructで音声とテキストから感情を判定する

初めに

いままで試してきた感情の判定は、音声データもしくはテキストデータの片方から判定を行っていました。 今回はマルチモーダルを用いて両方のデータから判定を行っていきます

開発環境

環境構築

必要なライブラリをインストールします

uv pip install --extra-index-url https://download.pytorch.org/whl/cu124 --pre torch torchvision torchaudio
uv pip install --upgrade git+https://github.com/huggingface/transformers
uv pip install accelerate sentencepiece soundfile

必要に応じて

uv pip install bitsandbytes

判定の実行

指定したラベルのみを返すように 「番号 → 日本語ラベル」で問題なければ 数値ラベル方式を使って実装をしていきます

import argparse, pathlib, torch, torchaudio
from transformers import (AutoProcessor, Qwen2AudioForConditionalGeneration,
                          LogitsProcessor, LogitsProcessorList)

MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
LABELS   = ["喜び", "怒り", "悲しみ", "恐れ", "驚き", "嫌悪", "中立", "その他"]
NUM_IDS  = list(range(len(LABELS)))          # [0,1,2,3,4,5,6,7]

# ---------- 1) CLI ----------
ap = argparse.ArgumentParser()
ap.add_argument("audio")
ap.add_argument("text")
ap.add_argument("--8bit", action="store_true")
args = ap.parse_args()

# ---------- 2) processor / model ----------
proc = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
tok  = proc.tokenizer
load_kw = dict(device_map="auto", trust_remote_code=True)
load_kw["load_in_8bit" if args.__dict__["8bit"] else "torch_dtype"] = (
    True if args.__dict__["8bit"] else torch.float16
)
model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID, **load_kw).eval()

# ---------- 3) 音声 ----------
wave, sr = torchaudio.load(args.audio)
target_sr = proc.feature_extractor.sampling_rate
if sr != target_sr:
    wave = torchaudio.functional.resample(wave, sr, target_sr)
wav_np = wave.squeeze().cpu().numpy()

# ---------- 4) プロンプト ----------
system_msg = (
    "日本語の感情を分類します。"
    "次の数値から **1 つだけ**返してください。\n"
    + "\n".join(f"{i}:{lbl}" for i, lbl in enumerate(LABELS))
)
prompt = proc.apply_chat_template(
    [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": [
            {"type": "audio", "audio": wav_np},
            {"type": "text",  "text": args.text},
            {"type": "text",  "text": "感情番号は?"}
        ]}
    ],
    tokenize=False, add_generation_prompt=True,
)

inputs = proc(text=prompt, audio=wav_np, sampling_rate=target_sr,
              return_tensors="pt").to(model.device)

# ---------- 5) ロジット制限 (0〜7 のみ) ----------
class AllowedNums(LogitsProcessor):
    def __init__(self, ids): self.ids = torch.tensor(ids)
    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.ids] = 0
        return scores + mask

proc_list = LogitsProcessorList([AllowedNums(NUM_IDS)])

gen = model.generate(**inputs,
                     logits_processor=proc_list,
                     max_new_tokens=1,
                     temperature=0.0, do_sample=False)

num_id = gen[0, -1].item()
print(LABELS[num_id])          # 最終的に日本語ラベルで表示

以下のように実行します

python .\run_audio.py '.\Kanjyou Ikari11.wav' "イタリア旅行で彼は、いくつか景勝の地として有名な都市、例えば、ナポリやフィレンツェを訪れた。"