ハイブリッド検索アプローチ「BM42」を動かしてみる

初めに

以下でBM25よりも精度がいいBM42が発表されたとあるので、実際に触ってみます

www.atpartners.co.jp

以下の記事で、過去にBM25を動かしています。

ayousanz.hatenadiary.jp

以下で今回の記事のリポジトリを公開しています

github.com

開発環境

ライブラリのインストール

pip install numpy
pip install transformers
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

BM42のindexおよび検索

以下でドキュメントのindexと検索をしてみます

from typing import List, Dict
import math
from transformers import AutoTokenizer, AutoModel
import torch

# サンプルデータセット
documents = [
    "Hello world is a common phrase in programming",
    "Python is a popular programming language",
    "Vector databases are useful for similarity search",
    "Machine learning models can be complex",
]

def compute_idf(term: str, documents: List[str]) -> float:
    doc_freq = sum(1 for doc in documents if term in doc.lower())
    return math.log((len(documents) - doc_freq + 0.5) / (doc_freq + 0.5) + 1)

def get_bm42_weights(text: str, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    attentions = outputs.attentions[-1][0, :, 0].mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    word_weights = {}
    current_word = ""
    current_weight = 0

    for token, weight in zip(tokens[1:-1], attentions[1:-1]):  # Exclude [CLS] and [SEP]
        if token.startswith("##"):
            current_word += token[2:]
            current_weight += weight
        else:
            if current_word:
                word_weights[current_word] = current_weight
            current_word = token
            current_weight = weight

    if current_word:
        word_weights[current_word] = current_weight

    return word_weights

# モデルとトークナイザーの初期化
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def compute_bm42_score(query: str, document: str, documents: List[str]) -> float:
    query_weights = get_bm42_weights(query, model, tokenizer)
    doc_weights = get_bm42_weights(document, model, tokenizer)

    score = 0
    for term, query_weight in query_weights.items():
        if term in doc_weights:
            idf = compute_idf(term, documents)
            score += query_weight * doc_weights[term] * idf

    return score

def search_bm42(query: str, documents: List[str]) -> List[Dict[str, float]]:
    scores = []
    for doc in documents:
        score = compute_bm42_score(query, doc, documents)
        scores.append({"document": doc, "score": score})

    return sorted(scores, key=lambda x: x["score"], reverse=True)

# 使用例
query = "programming language"

print("BM42 Results:")
for result in search_bm42(query, documents):
    print(f"Score: {result['score']:.4f} - {result['document']}")

結果は以下のようになります

BM25 Results:
Score: 1.9970 - Python is a popular programming language
Score: 0.6398 - Hello world is a common phrase in programming
Score: 0.0000 - Vector databases are useful for similarity search
Score: 0.0000 - Machine learning models can be complex

BM42 Results:
BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.
Score: 0.0117 - Python is a popular programming language
Score: 0.0069 - Hello world is a common phrase in programming
Score: 0.0000 - Vector databases are useful for similarity search
Score: 0.0000 - Machine learning models can be complex

検索エンジンのBM25-rankを試す

開発環境

ライブラリのインストール

以下のドキュメントにあるようにインストールをします

pip install rank_bm25

pypi.org

ドキュメントから関連文の抽出

まずはいくつかの文章をindexにします

from rank_bm25 import BM25Okapi

corpus = [
    "Hello there good man!",
    "It is quite windy in London",
    "How is the weather today?"
]

tokenized_corpus = [doc.split(" ") for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)

次に 指定した文に近いものを探します

query = "windy London"
tokenized_query = query.split(" ")

doc_scores = bm25.get_scores(tokenized_query)
print(doc_scores)

token_n = bm25.get_top_n(tokenized_query, corpus, n=1)
print(token_n)

結果は以下のように返ってきます

[0.         0.93729472 0.        ]
['It is quite windy in London']

Linuxで7zファイルをまとめて解凍する

開発環境

ライブラリのインストール

以下で7zの解凍するためのライブラリを入れます

sudo apt-get install p7zip-full

7zファイルの解凍

以下でフォルダ内にある7zファイルを解凍します

for file in *.7z; do 7z x "$file"; done

パスワードが設定されている場合は、以下でパスワードを入れて実行できます

for file in *.7z; do 7z x -p"pass word" "$file"; done

Spatial Reality Display ELF-SR2のセットアップをする

初めに

ELF-2を触らせていただく機会があったので、とりあえずセットアップ方法をまとめます

開発環境

SDKのインストール

以下から開発用のSDKをインストールします

www.sony.net

Unity プラグインのダウンロード

以下からUnityのプラグインをダウンロードすることができます

www.sony.net

ディスプレイの設定

以下のサイトの内容を確認して、設定を行う

  • ディスプレイのサイズ設定
  • ディスプレイの位置

knowledge.support.sony.jp

StableTTSでつくよみちゃんコーパスを使ってfine tuingをする

初めに

以下の記事でStableTTSで推論をしてみました。今回はfine tuingを行ってみます。

ayousanz.hatenadiary.jp

環境

  • L4 GPU
  • ubuntu22.04

準備

この辺は前の記事を同じですが、一応書いておきます

ライブラリのインストール

まずは動かすために必要なライブラリをインストールします

内部で音声周りの処理をするために依存しているライブラリを入れます

sudo apt update
sudo apt install ffmpeg

次に requirements.txt を入れるのですが、numpyが2.0.0が公開された影響でバグるので requirements.txt を変更して numpy<2にします

以下が変更した requirements.txtです

torch
torchaudio
matplotlib
numpy<2
tensorboard
pypinyin
jieba
eng_to_ipa
unidecode
inflect
pyopenjtalk-prebuilt
numba
tqdm
IPython
gradio
soundfile

モデルのアップロード

TTSの音声合成をするために、公式のモデルをダウンロードして配置をします。ただし日本語の事前学習モデルは公式から提供されていないため、今回は英語のモデルを使用します

事前学習モデルの取得URL

上記のURLから vocoder.ptcheckpoint-en_0.pt をダウンロードします。この際に checkpoint-en_0.ptcheckpoint_0.pt に名前を変更して、checkpoints/ 以下に配置します。また vocoder.ptはルートパスに配置します

プロジェクトルート/
│
├── checkpoints/
│   └── checkpoint_0.pt  (元の名前: checkpoint-en_0.pt)
│
└── vocoder.pt

音声ファイルのアップロードと学習用テキストの作成

fine tuingをするために、今回は つくよみちゃんコーパスを使用していきます。

以下の構成で、音声ファイル及び学習用のテキストを配置します

StableTTS/
│
├── filelists/
│   └── filelist.txt
│
└── audio/
    ├── VOICEACTRESS100_001.wav
    ├── VOICEACTRESS100_002.wav
    ├── VOICEACTRESS100_003.wav
    ├── VOICEACTRESS100_004.wav
    ├── VOICEACTRESS100_005.wav
    ├── VOICEACTRESS100_006.wav
    └── ....

また filelist.txt の中身は、以下のような書き方でファイルを作成をします

audio/VOICEACTRESS100_001.wav|また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。
audio/VOICEACTRESS100_002.wav|ニューイングランド風は、牛乳をベースとした、白いクリームスープであり、ボストンクラムチャウダーとも呼ばれる。
audio/VOICEACTRESS100_003.wav|コンピュータゲームのメーカーや、業界団体などに関連する人物のカテゴリ。

前処理

音声のアップロードとテキストの配置が終わったら、前処理を実行します。

python preprocess.py

これが終了した際に、filelistの中に filelist.json が作成されていて以下のような内容になっています

{"mel_path": "./stableTTS_datasets/mels/0_VOICEACTRESS100_001.pt", "phone": ["m", "a", "", "t", "a", ",", " ", "t", "o", "", "o", "d", "ʑ", "i", "n", "o", " ", "j", "o", "", "o", "n", "^", "i", ",", " ", "g", "o", "", "d", "a", "i", " ", "m", "j", "o", "", "o", "o", "", "o", "t", "o", " ", "j", "o", "", "b", "a", "ɾ", "e", "ɾ", "ɯ", ",", " ", "ʃ", "ɯ", "", "j", "o", "o", "n", "a", " ", "m", "j", "o", "", "o", "o", "", "o", "n", "o", " ", "t", "ʃ", "ɯ", "", "ɯ", "o", "", "o", "n", "^", "i", " ", "h", "a", "", "i", "s", "a", "ɾ", "e", "ɾ", "ɯ", " ", "k", "o", "", "t", "o", "", "m", "o", " ", "o", "", "o", "i", "."], "audio_path": "audio/VOICEACTRESS100_001.wav", "text": "また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。"}

fine tuningの実行

マルチGPUを使う際には、train.py の2行目の以下をコメントアウトを外して任意のGPUを指定してください

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

学習時のパラメータは config.py になります。バッチサイズは、VRAM 24GBの場合は、48で良さそうです

以下のコマンドを実行することで、学習をすることができます

python train.py

学習が終了後は以下のようなモデルが checkpointsに生成されています。

  • checkpoint_ステップ数.pt
  • optimizer_ステップ数.pt

fine tuingモデルで推論

vocoder.pt を checkpointsに移動します。その後に以下を実行して、WebUIを起動します

python webui.py

学習後に指定するモデルは、checkpoint_ステップ数.pt を指定します。

事前学習をする場合

事前学習をする場合は、fine tuingをする際に配置した事前学習モデルを削除して学習をすれば事前学習ができます

StableTTSで音声合成を試す

初めに

拡散モデルのTTSで(一応)日本語対応されているライブラリの StableTTSを触っていきます

github.com

環境

  • L4 GPU
  • ubuntu22.04

ライブラリのインストール

まずは動かすために必要なライブラリをインストールします

内部で音声周りの処理をするために依存しているライブラリを入れます

sudo apt update
sudo apt install ffmpeg

次に requirements.txt を入れるのですが、numpyが2.0.0が公開された影響でバグるので requirements.txt を変更して numpy<2にします

以下が変更した requirements.txtです

torch
torchaudio
matplotlib
numpy<2
tensorboard
pypinyin
jieba
eng_to_ipa
unidecode
inflect
pyopenjtalk-prebuilt
numba
tqdm
IPython
gradio
soundfile

モデルのアップロード

TTSの音声合成をするために、公式のモデルをダウンロードして配置をします。ただし日本語の事前学習モデルは公式から提供されていないため、今回は英語のモデルを使用します

事前学習モデルの取得URL

上記のURLから vocoder.ptcheckpoint-en_0.pt をダウンロードして、checkpoints/ 以下に配置します

WebUI画面の起動

以下にて、Web画面を起動できます

python webui.py

音声合成

以下を入力して、Sendを押すことで音声合成ができます。Step数を上げることでより精度が上がります(多分)

  • Input text(合成するテキスト文)
  • Reference Speaker(参考にする音声)
  • Language(言語設定)

合成される音声は、英語の事前学習モデルを使っているからか精度は良くないです

推論時間の計測

音声の推論時間を以下のコード差分により計測を行いました

import time

def inference(text: str, ref_audio: torch.Tensor, language: str, checkpoint_path: str, step: int=10) -> torch.Tensor:
    start_time = time.time()  # 関数の開始時間を記録

    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        load_start = time.time()
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) 
        last_checkpoint_path = checkpoint_path
        load_end = time.time()
        print(f"Model loading time: {load_end - load_start:.2f} seconds")
        
    phonemizer = g2p_mapping.get(language)
    
    prep_start = time.time()
    # prepare input for tts model
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)
    prep_end = time.time()
    print(f"Input preparation time: {prep_end - prep_start:.2f} seconds")
    
    # inference
    inference_start = time.time()
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=1, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)
    inference_end = time.time()
    print(f"Inference time: {inference_end - inference_start:.2f} seconds")
    
    # process output for gradio
    post_start = time.time()
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
    post_end = time.time()
    print(f"Post-processing time: {post_end - post_start:.2f} seconds")

    end_time = time.time()  # 関数の終了時間を記録
    total_time = end_time - start_time
    print(f"Total execution time: {total_time:.2f} seconds")

    return audio_output, mel_output
Input preparation time: 1.10 seconds
Inference time: 1.88 seconds
Post-processing time: 0.04 seconds
Total execution time: 3.08 seconds
Input preparation time: 0.08 seconds
Inference time: 0.23 seconds
Post-processing time: 0.03 seconds
Total execution time: 0.34 seconds

備考

動いた時点の各ライブラリのverリストです

Package                   Version
------------------------- -----------
absl-py                   2.1.0
aiofiles                  23.2.1
altair                    5.3.0
annotated-types           0.7.0
anyio                     4.4.0
asttokens                 2.4.1
attrs                     23.2.0
certifi                   2024.6.2
cffi                      1.16.0
charset-normalizer        3.3.2
click                     8.1.7
contourpy                 1.2.1
cycler                    0.12.1
Cython                    3.0.10
decorator                 5.1.1
dnspython                 2.6.1
email_validator           2.2.0
eng-to-ipa                0.0.2
exceptiongroup            1.2.1
executing                 2.0.1
fastapi                   0.111.0
fastapi-cli               0.0.4
ffmpy                     0.3.2
filelock                  3.15.3
fonttools                 4.53.0
fsspec                    2024.6.0
gradio                    4.36.1
gradio_client             1.0.1
grpcio                    1.64.1
h11                       0.14.0
httpcore                  1.0.5
httptools                 0.6.1
httpx                     0.27.0
huggingface-hub           0.23.4
idna                      3.7
importlib_resources       6.4.0
inflect                   7.3.0
ipython                   8.25.0
jedi                      0.19.1
jieba                     0.42.1
Jinja2                    3.1.4
jsonschema                4.22.0
jsonschema-specifications 2023.12.1
kiwisolver                1.4.5
llvmlite                  0.43.0
Markdown                  3.6
markdown-it-py            3.0.0
MarkupSafe                2.1.5
matplotlib                3.9.0
matplotlib-inline         0.1.7
mdurl                     0.1.2
more-itertools            10.3.0
mpmath                    1.3.0
networkx                  3.3
numba                     0.60.0
numpy                     1.26.4
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         8.9.2.26
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-nccl-cu12          2.20.5
nvidia-nvjitlink-cu12     12.5.40
nvidia-nvtx-cu12          12.1.105
orjson                    3.10.5
packaging                 24.1
pandas                    2.2.2
parso                     0.8.4
pexpect                   4.9.0
pillow                    10.3.0
pip                       22.0.2
prompt_toolkit            3.0.47
protobuf                  4.25.3
ptyprocess                0.7.0
pure-eval                 0.2.2
pycparser                 2.22
pydantic                  2.7.4
pydantic_core             2.18.4
pydub                     0.25.1
Pygments                  2.18.0
pyopenjtalk-prebuilt      0.3.0
pyparsing                 3.1.2
pypinyin                  0.51.0
python-dateutil           2.9.0.post0
python-dotenv             1.0.1
python-multipart          0.0.9
pytz                      2024.1
PyYAML                    6.0.1
referencing               0.35.1
requests                  2.32.3
rich                      13.7.1
rpds-py                   0.18.1
ruff                      0.4.10
semantic-version          2.10.0
setuptools                59.6.0
shellingham               1.5.4
six                       1.16.0
sniffio                   1.3.1
soundfile                 0.12.1
stack-data                0.6.3
starlette                 0.37.2
sympy                     1.12.1
tensorboard               2.17.0
tensorboard-data-server   0.7.2
tomlkit                   0.12.0
toolz                     0.12.1
torch                     2.3.1
torchaudio                2.3.1
tqdm                      4.66.4
traitlets                 5.14.3
triton                    2.3.1
typeguard                 4.3.0
typer                     0.12.3
typing_extensions         4.12.2
tzdata                    2024.1
ujson                     5.10.0
Unidecode                 1.3.8
urllib3                   2.2.2
uvicorn                   0.30.1
uvloop                    0.19.0
watchfiles                0.22.0
wcwidth                   0.2.13
websockets                11.0.3
Werkzeug                  3.0.3

時系列基盤amazon chronos-t5をサンプルデータでfine tuningをする

初めに

過去に 比較的小さい時系列基盤モデルを触りました。こちらはfine tuingはできなかったので、fine tuingができてより大きいモデルを触っていきます

ayousanz.hatenadiary.jp

環境

  • L4 GPU
  • ubuntu22.04
  • jupter notebook

ライブラリのインストール

# library install
!pip install git+https://github.com/amazon-science/chronos-forecasting.git

データの取得と分析

以下でデータを取得・分析をしていきます

import pandas as pd

# データ読み込み
# https://github.com/zhouhaoyi/ETDataset/blob/main/ETT-small/ETTh1.csv
df = pd.read_csv("ETTh1.csv")
print(len(df))
df.head(2)

以下のようにデータの中が表示されます

データ形式の変換

推論に用いれる形に変換をします

import torch

context_length = 512
forecast_horizon = 96

# データセット分割
df_train = df.iloc[-(context_length+forecast_horizon):-forecast_horizon]
df_test = df.iloc[-forecast_horizon:]

# 形式の変更
train_tensor = torch.tensor(df_train[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
train_tensor = train_tensor.t()
test_tensor = torch.tensor(df_test[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
test_tensor = test_tensor.t()

推論

モデルのロード

以下でモデルをロードします

import pandas as pd
import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="cuda",  # use "cpu" for CPU inference and "mps" for Apple Silicon
    torch_dtype=torch.bfloat16,
)

推論実行及びグラフにプロット

forecast = pipeline.predict(train_tensor, forecast_horizon, limit_prediction_length=False)
forecast_median_tensor, _ = torch.median(forecast, dim=1)

import matplotlib.pyplot as plt

channel_idx = 6
time_index = 0

history = train_tensor[channel_idx, :].detach().numpy()
true = test_tensor[channel_idx, :].detach().numpy()
pred = forecast_median_tensor[channel_idx, :].detach().numpy()

plt.figure(figsize=(12, 4))

# Plotting the first time series from history
plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')

# Plotting ground truth and prediction
num_forecasts = len(true)

offset = len(history)
plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (96 timesteps)', color='darkblue', linestyle='--', alpha=0.5)
plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (96 timesteps)', color='red', linestyle='--')

plt.title(f"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})", fontsize=18)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.legend(fontsize=14)
plt.show()

以下のようにゼロショットで推論ができました。かなり一致しています

fine tuing

ライブラリのインストール

!pip install "chronos[training] @ git+https://github.com/amazon-science/chronos-forecasting.git"
!git clone https://github.com/amazon-science/chronos-forecasting.git

# Clone the ibm/tsfm
! git clone https://github.com/IBM/tsfm.git

# Change directory. Move inside the tsfm repo.
%cd tsfm

# Install the tsfm library
! pip install ".[notebooks]"

学習用のデータの変換

from pathlib import Path
from typing import List, Optional, Union

import numpy as np
from gluonts.dataset.arrow import ArrowWriter

from pathlib import Path
from typing import List, Optional, Union

import numpy as np
from gluonts.dataset.arrow import ArrowWriter


def convert_to_arrow(
    path: Union[str, Path],
    time_series: Union[List[np.ndarray], np.ndarray],
    start_times: Optional[Union[List[np.datetime64], np.ndarray]] = None,
    compression: str = "lz4",
):
    if start_times is None:
        # Set an arbitrary start time
        start_times = [np.datetime64("2000-01-01 00:00", "s")] * len(time_series)

    assert len(time_series) == len(start_times)

    dataset = [
        {"start": start, "target": ts} for ts, start in zip(time_series, start_times)
    ]
    ArrowWriter(compression=compression).write_to_file(
        dataset,
        path=path,
    )


# Convert to GluonTS arrow format
cols = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
convert_to_arrow(
    path = "./etth1-train.arrow", 
    time_series=[np.array(df_train[col]) for col in cols],
    start_times=[pd.to_datetime(df_train["date"]).values[0]] * len(cols),
)

以下のようなファイルが生成されます

参考サイト

hamaruki.com

note.com

fine tuing

以下で追加学習をします

import yaml

batch_size = 2  # バッチサイズ
num_steps = train_size/batch_size
print("steps:" + str(num_steps))

# Fine Tuningの設定
config_data = {
    'training_data_paths': ["./etth1-train.arrow"],  # 学習データファイルのパス
    'probability': [1.0],
    'output_dir': './output/',  # 学習結果の出力ディレクトリ
    'context_length': context_length,
    'prediction_length': forecast_horizon,
    'max_steps': num_steps,
    'per_device_train_batch_size': batch_size,
    'learning_rate': 0.001,
    'model_id': 'amazon/chronos-t5-large',
    # 'model_id': 'amazon/chronos-t5-base',
    'random_init': False,  # 事前学習済みモデルを使用
    'tf32': True,        # NVIDIA GPUの場合Trueにする
    'gradient_accumulation_steps':2,
}

# 設定ファイルをYAML形式で保存
config_file_path = './ft_config.yaml'
with open(config_file_path, 'w') as file:
    yaml.dump(config_data, file)

def fine_tune_chronos(train_file_path, config_file_path):
    """
    chronos-t5モデルをFine Tuningする関数

    Args:
        train_file_path (str): 学習用データファイルのパス
        config_file_path (str): Fine Tuning設定ファイルのパス
    """

    # Fine Tuningの実行
    !CUDA_VISIBLE_DEVICES=0 python chronos-forecasting/scripts/training/train.py --config {config_file_path}

# Fine Tuningの実行
fine_tune_chronos("./etth1-train.arrow", config_file_path)

ログは以下のようになりました

{'loss': 3.8517, 'grad_norm': 1.0407471656799316, 'learning_rate': 0.0009282433983926521, 'epoch': 0.07}
{'loss': 3.7261, 'grad_norm': 0.7981559038162231, 'learning_rate': 0.0008564867967853042, 'epoch': 0.14}
{'loss': 3.6565, 'grad_norm': 1.0834617614746094, 'learning_rate': 0.0007847301951779565, 'epoch': 0.22}
{'loss': 3.5609, 'grad_norm': 0.6380050182342529, 'learning_rate': 0.0007129735935706086, 'epoch': 0.29}
{'loss': 3.4786, 'grad_norm': 1.107723355293274, 'learning_rate': 0.0006412169919632607, 'epoch': 0.36}
{'loss': 3.3614, 'grad_norm': 0.8139169812202454, 'learning_rate': 0.0005694603903559128, 'epoch': 0.43}
{'loss': 3.2374, 'grad_norm': 0.9819308519363403, 'learning_rate': 0.0004977037887485649, 'epoch': 0.5}
{'loss': 3.1432, 'grad_norm': 1.3025404214859009, 'learning_rate': 0.000425947187141217, 'epoch': 0.57}
{'loss': 3.0098, 'grad_norm': 1.464908242225647, 'learning_rate': 0.0003541905855338691, 'epoch': 0.65}
{'loss': 2.8967, 'grad_norm': 1.4902337789535522, 'learning_rate': 0.00028243398392652127, 'epoch': 0.72}
{'loss': 2.7994, 'grad_norm': 1.2662851810455322, 'learning_rate': 0.00021067738231917335, 'epoch': 0.79}
{'loss': 2.7279, 'grad_norm': 1.741388201713562, 'learning_rate': 0.0001389207807118255, 'epoch': 0.86}
{'loss': 2.6553, 'grad_norm': 1.594014286994934, 'learning_rate': 6.716417910447761e-05, 'epoch': 0.93}
{'train_runtime': 7191.7126, 'train_samples_per_second': 3.876, 'train_steps_per_second': 0.969, 'train_loss': 3.195165220275949, 'epoch': 1.0}
100%|█████████████████████████████████████| 6968/6968 [1:59:51<00:00,  1.03s/it]

追加学習モデルを使った推論

from chronos import ChronosPipeline
import matplotlib.pyplot as plt

def predict_with_chronos(train_tensor, forecast_horizon, model_name="amazon/chronos-t5-large", device_map="cuda"):
    """
    chronos-t5モデルで予測を行う関数

    Args:
        train_tensor (torch.Tensor): 学習用データテンソル
        forecast_horizon (int): 予測する長さ
        model_name (str): モデル名 (デフォルト: "amazon/chronos-t5-large")
        device_map (str): デバイス ("cuda" or "cpu")

    Returns:
        forecast_median_tensor (torch.Tensor): 予測結果の中央値テンソル
    """

    pipeline = ChronosPipeline.from_pretrained(
        model_name,
        device_map=device_map,
        torch_dtype=torch.bfloat16,  # 計算精度を指定
    )

    # 予測の実行 (limit_prediction_length=Falseで予測長を制限しない)
    forecast = pipeline.predict(train_tensor, forecast_horizon, limit_prediction_length=False)
    forecast_median_tensor, _ = torch.median(forecast, dim=1)  # 予測結果の中央値を計算

    return forecast_median_tensor, forecast

def visualize_prediction(train_tensor, test_tensor, forecast_median_tensor, channel_idx=6):
    """
    予測結果を可視化する関数

    Args:
        train_tensor (torch.Tensor): 学習用データテンソル
        test_tensor (torch.Tensor): テスト用データテンソル
        forecast_median_tensor (torch.Tensor): 予測結果の中央値テンソル
        channel_idx (int): 可視化するチャンネルのインデックス (デフォルト: 6, OT)
    """

    history = train_tensor[channel_idx, :].detach().numpy()  # 学習用データ
    true = test_tensor[channel_idx, :].detach().numpy()      # 実測値
    pred = forecast_median_tensor[channel_idx, :].detach().numpy()  # 予測値

    plt.figure(figsize=(12, 4))
    plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')
    plt.plot(range(len(history), len(history) + len(true)), true, label='Ground Truth (96 timesteps)', color='darkblue', linestyle='--', alpha=0.5)
    plt.plot(range(len(history), len(history) + len(pred)), pred, label='Forecast (96 timesteps)', color='red', linestyle='--')
    plt.title(f"ETTh1 (Hourly) - Channel {channel_idx}", fontsize=18)
    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Value', fontsize=14)
    plt.legend(fontsize=14)
    plt.show()

# Fine Tuning後のモデルで予測
forecast_median_tensor_ft, forecast_ft = predict_with_chronos(train_tensor, forecast_horizon=forecast_horizon, model_name="./output/run-4/checkpoint-final/")

# 予測結果の可視化 (Fine Tuning後)
visualize_prediction(train_tensor, test_tensor, forecast_median_tensor_ft)