Google Colabで時系列基盤モデルのGoogle timesfmを試す

初めに

時系列基盤モデルでどんなことができるのか気になったので、以下の記事を実際に試してみます

note.com

開発環境

ライブラリのインストール

ライブラリをインストールします

!pip install utilsforecast

データのダウンロードおよび整理

import pandas as pd

# データ読み込み
# https://github.com/zhouhaoyi/ETDataset/blob/main/ETT-small/ETTh1.csv
df = pd.read_csv("ETTh1.csv")
print(len(df))
df.head(2)

import torch

context_length = 512
forecast_horizon = 96

# データセット分割
df_train = df.iloc[-(context_length+forecast_horizon):-forecast_horizon]
df_test = df.iloc[-forecast_horizon:]

# 形式の変更
train_tensor = torch.tensor(df_train[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
train_tensor = train_tensor.t()
test_tensor = torch.tensor(df_test[["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]].values, dtype=torch.float)
test_tensor = test_tensor.t()

モデルのロード

以下でモデルでロードします

import timesfm

tfm = timesfm.TimesFm(
    context_len=context_length,
    horizon_len=forecast_horizon,
    input_patch_len=32,
    output_patch_len=128,
    num_layers=20,
    model_dims=1280,
    backend="gpu",
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

推論

以下で推論します

# 予測の実行
frequency_input = [0] * train_tensor.size(0)
point_forecast, experimental_quantile_forecast = tfm.forecast(
    train_tensor,
    freq=frequency_input,
)
forecast_tensor = torch.tensor(point_forecast)
quantile_tensor = torch.tensor(experimental_quantile_forecast)

import matplotlib.pyplot as plt

channel_idx = 6
time_index = 0

history = train_tensor[channel_idx, :].detach().numpy()
true = test_tensor[channel_idx, :].detach().numpy()
pred = forecast_tensor[channel_idx, :].detach().numpy()

plt.figure(figsize=(12, 4))

# Plotting the first time series from history
plt.plot(range(len(history)), history, label='History (512 timesteps)', c='darkblue')

# Plotting ground truth and prediction
num_forecasts = len(true)

offset = len(history)
plt.plot(range(offset, offset + len(true)), true, label='Ground Truth (96 timesteps)', color='darkblue', linestyle='--', alpha=0.5)
plt.plot(range(offset, offset + len(pred)), pred, label='Forecast (96 timesteps)', color='red', linestyle='--')

plt.title(f"ETTh1 (Hourly) -- (idx={time_index}, channel={channel_idx})", fontsize=18)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Value', fontsize=14)
plt.legend(fontsize=14)
plt.show()

出力される予測画像は以下になります

各種ライブラリのver

pip list

で出力された各種ライブラリのverです

Package                           Version
--------------------------------- ---------------------
absl-py                           1.4.0
aiohttp                           3.9.5
aiosignal                         1.3.1
alabaster                         0.7.16
albumentations                    1.3.1
altair                            4.2.2
annotated-types                   0.7.0
anyio                             3.7.1
argon2-cffi                       23.1.0
argon2-cffi-bindings              21.2.0
array_record                      0.5.1
arviz                             0.15.1
astropy                           5.3.4
astunparse                        1.6.3
async-timeout                     4.0.3
atpublic                          4.1.0
attrs                             23.2.0
audioread                         3.0.1
autograd                          1.6.2
Babel                             2.15.0
backcall                          0.2.0
beautifulsoup4                    4.12.3
bidict                            0.23.1
bigframes                         1.8.0
bleach                            6.1.0
blinker                           1.4
blis                              0.7.11
blosc2                            2.0.0
bokeh                             3.3.4
bqplot                            0.12.43
branca                            0.7.2
build                             1.2.1
CacheControl                      0.14.0
cachetools                        5.3.3
catalogue                         2.0.10
certifi                           2024.6.2
cffi                              1.16.0
chardet                           5.2.0
charset-normalizer                3.3.2
chex                              0.1.86
click                             8.1.7
click-plugins                     1.1.1
cligj                             0.7.2
cloudpathlib                      0.16.0
cloudpickle                       2.2.1
clu                               0.0.11
cmake                             3.27.9
cmdstanpy                         1.2.3
colorama                          0.4.6
colorcet                          3.1.0
colorlover                        0.3.0
colour                            0.1.5
community                         1.0.0b1
confection                        0.1.5
cons                              0.4.6
contextlib2                       21.6.0
contourpy                         1.2.1
cryptography                      42.0.7
cuda-python                       12.2.1
cudf-cu12                         24.6.0
cufflinks                         0.17.3
cupy-cuda12x                      12.2.0
cvxopt                            1.3.2
cvxpy                             1.3.4
cycler                            0.12.1
cymem                             2.0.8
Cython                            3.0.10
dask                              2023.8.1
datascience                       0.17.6
db-dtypes                         1.2.0
dbus-python                       1.2.18
debugpy                           1.6.6
decorator                         4.4.2
defusedxml                        0.7.1
distributed                       2023.8.1
distro                            1.7.0
dlib                              19.24.4
dm-tree                           0.1.8
docstring_parser                  0.16
docutils                          0.18.1
dopamine_rl                       4.0.9
duckdb                            0.10.3
earthengine-api                   0.1.405
easydict                          1.13
ecos                              2.0.13
editdistance                      0.6.2
eerepr                            0.0.4
einops                            0.7.0
einshape                          1.0
en-core-web-sm                    3.7.1
entrypoints                       0.4
et-xmlfile                        1.1.0
etils                             1.7.0
etuples                           0.3.9
exceptiongroup                    1.2.1
fastai                            2.7.15
fastcore                          1.5.43
fastdownload                      0.0.7
fastjsonschema                    2.19.1
fastprogress                      1.0.3
fastrlock                         0.8.2
fiddle                            0.3.0
filelock                          3.14.0
fiona                             1.9.6
firebase-admin                    5.3.0
Flask                             2.2.5
flatbuffers                       1.12
flax                              0.8.2
folium                            0.14.0
fonttools                         4.53.0
frozendict                        2.4.4
frozenlist                        1.4.1
fsspec                            2023.6.0
future                            0.18.3
gast                              0.4.0
gcsfs                             2023.6.0
GDAL                              3.6.4
gdown                             5.1.0
geemap                            0.32.1
gensim                            4.3.2
geocoder                          1.38.1
geographiclib                     2.0
geopandas                         0.13.2
geopy                             2.3.0
gin-config                        0.5.0
glob2                             0.7
google                            2.0.3
google-ai-generativelanguage      0.6.4
google-api-core                   2.11.1
google-api-python-client          2.84.0
google-auth                       2.27.0
google-auth-httplib2              0.1.1
google-auth-oauthlib              0.4.6
google-cloud-aiplatform           1.52.0
google-cloud-bigquery             3.21.0
google-cloud-bigquery-connection  1.12.1
google-cloud-bigquery-storage     2.25.0
google-cloud-core                 2.3.3
google-cloud-datastore            2.15.2
google-cloud-firestore            2.11.1
google-cloud-functions            1.13.3
google-cloud-iam                  2.15.0
google-cloud-language             2.13.3
google-cloud-resource-manager     1.12.3
google-cloud-storage              2.8.0
google-cloud-translate            3.11.3
google-colab                      1.0.0
google-crc32c                     1.5.0
google-generativeai               0.5.4
google-pasta                      0.2.0
google-resumable-media            2.7.0
googleapis-common-protos          1.63.1
googledrivedownloader             0.4
graph-compression-google-research 0.0.4
graphviz                          0.20.1
greenlet                          3.0.3
grpc-google-iam-v1                0.13.0
grpcio                            1.64.1
grpcio-status                     1.48.2
gspread                           6.0.2
gspread-dataframe                 3.3.1
gym                               0.25.2
gym-notices                       0.0.8
h5netcdf                          1.3.0
h5py                              3.9.0
holidays                          0.50
holoviews                         1.17.1
html5lib                          1.1
httpimport                        1.3.1
httplib2                          0.22.0
huggingface-hub                   0.23.2
humanize                          4.7.0
hyperopt                          0.2.7
ibis-framework                    8.0.0
idna                              3.7
imageio                           2.31.6
imageio-ffmpeg                    0.5.1
imagesize                         1.4.1
imbalanced-learn                  0.10.1
imgaug                            0.4.0
immutabledict                     4.2.0
importlib_metadata                7.1.0
importlib_resources               6.4.0
imutils                           0.5.4
inflect                           7.0.0
iniconfig                         2.0.0
intel-openmp                      2023.2.4
ipyevents                         2.0.2
ipyfilechooser                    0.6.0
ipykernel                         5.5.6
ipyleaflet                        0.18.2
ipython                           7.34.0
ipython-genutils                  0.2.0
ipython-sql                       0.5.0
ipytree                           0.2.2
ipywidgets                        7.7.1
itsdangerous                      2.2.0
jax                               0.4.26
jax-bitempered-loss               0.0.2
jaxlib                            0.4.26+cuda12.cudnn89
jaxtyping                         0.2.28
jedi                              0.19.1
jeepney                           0.7.1
jellyfish                         1.0.4
jieba                             0.42.1
Jinja2                            3.1.4
joblib                            1.4.2
jsonpickle                        3.0.4
jsonschema                        4.19.2
jsonschema-specifications         2023.12.1
jupyter                           1.0.0
jupyter-client                    6.1.12
jupyter-console                   6.1.0
jupyter_core                      5.7.2
jupyter-http-over-ws              0.0.8
jupyter-server                    1.24.0
jupyterlab_pygments               0.3.0
jupyterlab_widgets                3.0.11
kaggle                            1.6.14
kagglehub                         0.2.5
keras                             2.9.0
Keras-Preprocessing               1.1.2
keyring                           23.5.0
kiwisolver                        1.4.5
langcodes                         3.4.0
language_data                     1.2.0
launchpadlib                      1.10.16
lazr.restfulclient                0.14.4
lazr.uri                          1.0.6
lazy_loader                       0.4
libclang                          18.1.1
libcst                            1.4.0
librosa                           0.10.2.post1
lightgbm                          4.1.0
lingvo                            0.12.7
linkify-it-py                     2.0.3
llvmlite                          0.41.1
locket                            1.0.0
logical-unification               0.4.6
lxml                              4.9.4
malloy                            2023.1067
marisa-trie                       1.1.1
Markdown                          3.6
markdown-it-py                    3.0.0
MarkupSafe                        2.1.5
matplotlib                        3.7.1
matplotlib-inline                 0.1.7
matplotlib-venn                   0.11.10
mdit-py-plugins                   0.4.1
mdurl                             0.1.2
mesh-tensorflow                   0.1.21
miniKanren                        1.0.3
missingno                         0.5.2
mistune                           0.8.4
mizani                            0.9.3
mkl                               2023.2.0
ml-collections                    0.1.1
ml-dtypes                         0.4.0
mlxtend                           0.22.0
model-pruning-google-research     0.0.5
more-itertools                    10.1.0
moviepy                           1.0.3
mpmath                            1.3.0
msgpack                           1.0.8
multidict                         6.0.5
multipledispatch                  1.0.0
multitasking                      0.0.11
murmurhash                        1.0.10
music21                           9.1.0
natsort                           8.4.0
nbclassic                         1.1.0
nbclient                          0.10.0
nbconvert                         6.5.4
nbformat                          5.10.4
nest-asyncio                      1.6.0
networkx                          3.3
nibabel                           4.0.2
nltk                              3.8.1
notebook                          6.5.5
notebook_shim                     0.2.4
numba                             0.58.1
numexpr                           2.10.0
numpy                             1.26.4
nvtx                              0.2.10
oauth2client                      4.1.3
oauthlib                          3.2.2
opencv-contrib-python             4.8.0.76
opencv-python                     4.8.0.76
opencv-python-headless            4.10.0.82
openpyxl                          3.1.3
opt-einsum                        3.3.0
optax                             0.2.2
optax-shampoo                     0.0.6
orbax-checkpoint                  0.5.9
osqp                              0.6.2.post8
packaging                         24.0
pandas                            2.0.3
pandas-datareader                 0.10.0
pandas-gbq                        0.19.2
pandas-stubs                      2.0.3.230814
pandocfilters                     1.5.1
panel                             1.3.8
param                             2.1.0
parso                             0.8.4
parsy                             2.1
partd                             1.4.2
pathlib                           1.0.1
patsy                             0.5.6
paxml                             1.4.0
peewee                            3.17.5
pexpect                           4.9.0
pickleshare                       0.7.5
Pillow                            9.4.0
pip                               23.1.2
pip-tools                         6.13.0
platformdirs                      4.2.2
plotly                            5.15.0
plotnine                          0.12.4
pluggy                            1.5.0
polars                            0.20.2
pooch                             1.8.1
portalocker                       2.8.2
portpicker                        1.5.2
praxis                            1.4.0
prefetch-generator                1.0.3
preshed                           3.0.9
prettytable                       3.10.0
proglog                           0.1.10
progressbar2                      4.2.0
prometheus_client                 0.20.0
promise                           2.3
prompt_toolkit                    3.0.45
prophet                           1.1.5
proto-plus                        1.23.0
protobuf                          3.19.6
psutil                            5.9.5
psycopg2                          2.9.9
ptyprocess                        0.7.0
py-cpuinfo                        9.0.0
py4j                              0.10.9.7
pyarrow                           16.1.0
pyarrow-hotfix                    0.6
pyasn1                            0.6.0
pyasn1_modules                    0.4.0
pycocotools                       2.0.7
pycparser                         2.22
pydantic                          2.7.3
pydantic_core                     2.18.4
pydata-google-auth                1.8.2
pydot                             1.4.2
pydot-ng                          2.0.0
pydotplus                         2.0.2
PyDrive                           1.3.1
PyDrive2                          1.6.3
pyerfa                            2.0.1.4
pygame                            2.5.2
pyglove                           0.4.4
Pygments                          2.16.1
PyGObject                         3.42.1
PyJWT                             2.3.0
pymc                              5.10.4
pymystem3                         0.2.0
pynvjitlink-cu12                  0.2.3
PyOpenGL                          3.1.7
pyOpenSSL                         24.1.0
pyparsing                         3.1.2
pyperclip                         1.8.2
pyproj                            3.6.1
pyproject_hooks                   1.1.0
pyshp                             2.3.1
PySocks                           1.7.1
pytensor                          2.18.6
pytest                            7.4.4
python-apt                        0.0.0
python-box                        7.1.1
python-dateutil                   2.8.2
python-louvain                    0.16
python-slugify                    8.0.4
python-utils                      3.8.2
pytz                              2023.4
pyviz_comms                       3.0.2
PyWavelets                        1.6.0
PyYAML                            6.0.1
pyzmq                             24.0.1
qdldl                             0.1.7.post2
qtconsole                         5.5.2
QtPy                              2.4.1
qudida                            0.0.4
ratelim                           0.1.6
referencing                       0.35.1
regex                             2024.5.15
requests                          2.31.0
requests-oauthlib                 1.3.1
requirements-parser               0.9.0
rich                              13.7.1
rmm-cu12                          24.6.0
rouge-score                       0.1.2
rpds-py                           0.18.1
rpy2                              3.4.2
rsa                               4.9
sacrebleu                         2.4.2
safetensors                       0.4.3
scikit-image                      0.19.3
scikit-learn                      1.2.2
scipy                             1.11.4
scooby                            0.10.0
scs                               3.2.4.post2
seaborn                           0.13.1
SecretStorage                     3.3.1
Send2Trash                        1.8.3
sentencepiece                     0.1.99
seqio-nightly                     0.0.17.dev20231010
setuptools                        67.7.2
shapely                           2.0.4
simple_parsing                    0.1.5
six                               1.16.0
sklearn-pandas                    2.2.0
smart-open                        6.4.0
sniffio                           1.3.1
snowballstemmer                   2.2.0
sortedcontainers                  2.4.0
soundfile                         0.12.1
soupsieve                         2.5
soxr                              0.3.7
spacy                             3.7.4
spacy-legacy                      3.0.12
spacy-loggers                     1.0.5
Sphinx                            5.0.2
sphinxcontrib-applehelp           1.0.8
sphinxcontrib-devhelp             1.0.6
sphinxcontrib-htmlhelp            2.0.5
sphinxcontrib-jsmath              1.0.1
sphinxcontrib-qthelp              1.0.7
sphinxcontrib-serializinghtml     1.1.10
SQLAlchemy                        2.0.30
sqlglot                           20.11.0
sqlparse                          0.5.0
srsly                             2.4.8
stanio                            0.5.0
statsmodels                       0.14.2
StrEnum                           0.4.15
sympy                             1.12.1
t5                                0.9.4
tables                            3.8.0
tabulate                          0.9.0
tbb                               2021.12.0
tblib                             3.0.0
tenacity                          8.3.0
tensorboard                       2.9.1
tensorboard-data-server           0.6.1
tensorboard-plugin-wit            1.8.1
tensorflow                        2.9.3
tensorflow-datasets               4.8.3
tensorflow-estimator              2.9.0
tensorflow-gcs-config             2.15.0
tensorflow-hub                    0.16.1
tensorflow-io-gcs-filesystem      0.37.0
tensorflow-metadata               1.12.0
tensorflow-probability            0.23.0
tensorflow-text                   2.9.0
tensorstore                       0.1.55
termcolor                         2.4.0
terminado                         0.18.1
text-unidecode                    1.3
textblob                          0.17.1
tf-keras                          2.15.0
tf-slim                           1.1.0
tfds-nightly                      4.8.3.dev202303280045
thinc                             8.2.3
threadpoolctl                     3.5.0
tifffile                          2024.5.22
timesfm                           0.0.1
tinycss2                          1.3.0
tokenizers                        0.19.1
toml                              0.10.2
tomli                             2.0.1
toolz                             0.12.1
torch                             2.3.0+cu121
torchaudio                        2.3.0+cu121
torchsummary                      1.5.1
torchtext                         0.18.0
torchvision                       0.18.0+cu121
tornado                           6.3.3
tqdm                              4.66.4
traitlets                         5.7.1
traittypes                        0.2.1
transformers                      4.41.2
triton                            2.3.0
tweepy                            4.14.0
typeguard                         2.13.3
typer                             0.9.4
types-pytz                        2024.1.0.20240417
types-setuptools                  70.0.0.20240524
typing_extensions                 4.12.1
tzdata                            2024.1
tzlocal                           5.2
uc-micro-py                       1.0.3
uritemplate                       4.1.1
urllib3                           2.0.7
utilsforecast                     0.1.10
vega-datasets                     0.9.0
wadllib                           1.3.6
wasabi                            1.1.3
wcwidth                           0.2.13
weasel                            0.3.4
webcolors                         1.13
webencodings                      0.5.1
websocket-client                  1.8.0
Werkzeug                          3.0.3
wheel                             0.43.0
widgetsnbextension                3.6.6
wordcloud                         1.9.3
wrapt                             1.14.1
xarray                            2023.7.0
xarray-einstats                   0.7.0
xgboost                           2.0.3
xlrd                              2.0.1
xyzservices                       2024.4.0
yarl                              1.9.4
yellowbrick                       1.5
yfinance                          0.2.40
zict                              3.0.0
zipp                              3.19.1

GoogleColobでstabilityai/stable-audio-open-1.0を動かす

初めに

Audio生成でかなり精度が高いモデルが出たので触ってみます

huggingface.co

開発環境

ライブラリのインストール

# 必要なパッケージのインストール
!pip install torch torchaudio einops stable-audio-tools

シークレットトークンの設定

from google.colab import drive
drive.mount('/content/drive')

import os
from google.colab import auth

# Colabのシークレットに設定したトークンを取得
auth.authenticate_user()
token = os.getenv('HF_TOKEN')

if token is None:
    raise ValueError("HF_TOKENが設定されていません。Colabのシークレット機能を使用して設定してください。")

os.environ['HF_TOKEN'] = token

!pip install huggingface_hub

モデルのロードおよび生成

import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from huggingface_hub import login

# Hugging Faceにログイン
login(token=os.environ['HF_TOKEN'])

# Download model
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

model = model.to(device)

# Set up text and timing conditioning
conditioning = [{
    "prompt": "128 BPM tech house drum loop",
    "seconds_start": 0, 
    "seconds_total": 30
}]

# Generate stereo audio
output = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device
)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)

UnityでReadOnly状態AnimationClipをyamlファイルから変更してEventを追加する

初めに

UnityでAnimation Clipを使ってアニメーションを制御することはよくあります。しかし、こちらを変更する方法として、FBXからD&Dをする方法もありますができない場合もあります(FBXが手元にない場合)。 この場合の対応方法として、animation clipを直接テキストとして変更する方法を今回は使っていきます

ReadOnlyになっているAnimation Clip

開発環境

  • Unity 2022.3.10f1

アニメーションは、以下のアセットを使用しています。

Animationにイベントを追加

アニメーションクリップのyamlの確認

(注) サンプルとして使用するアセットは、通常通り使用できるため ReadOnlyにはなっていません。例として使用しています

まずは、追加したいアニメーションをテキストエディタで開きます。

以下のようなテキストが表示されるため、一番下の行にある m_Events を変更していきます。

%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!74 &7400000
AnimationClip:
  m_ObjectHideFlags: 0
  m_CorrespondingSourceObject: {fileID: 0}
  m_PrefabInstance: {fileID: 0}
  m_PrefabAsset: {fileID: 0}
  m_Name: Walk_wing
  serializedVersion: 7
  m_Legacy: 0
  m_Compressed: 0
  m_UseHighQualityCurve: 0
  m_RotationCurves:
  - curve:
      serializedVersion: 2
      m_Curve:
      - serializedVersion: 3
        time: 0
        value: {x: 0, y: -0, z: -0, w: 1}
        inSlope: {x: 0, y: 0, z: 0, w: 0}
        outSlope: {x: 0, y: 0, z: 0, w: 0}
        tangentMode: 0
        weightedMode: 0
        inWeight: {x: 0.33333334, y: 0.33333334, z: 0.33333334, w: 0.33333334}
        outWeight: {x: 0.33333334, y: 0.33333334, z: 0.33333334, w: 0.33333334}
....
  m_EditorCurves: []
  m_EulerEditorCurves: []
  m_HasGenericRootTransform: 0
  m_HasMotionFloatCurves: 0
  m_Events: []

アニメーションイベントのyaml定義の確認

まずは、通常通り UnityのAnimation Eventから以下のように Test という Animation Eventを追加します。

この時にテキストエディタ側で確認をすると以下のようになっています。

  m_Events:
  - time: 0
    functionName: Test
    data: 
    objectReferenceParameter: {fileID: 0}
    floatParameter: 0
    intParameter: 0
    messageOptions: 0

任意の時間にイベントを追加

アニメーションのフレームを確認する必要があります。今回はFBXがあるので確認をすると以下のよう30FPSになっています。

1frameは 1/30s ≒ 0.033sです。そのため任意のフレームに入れたい場合は、1/30 × フレーム数になります。

例として10フレーム目と20フレーム目に追加するといます。この場合は以下のように追加をします。

  m_Events:
  - time: 0.33333334
    functionName: Test
    data: 
    objectReferenceParameter: {fileID: 0}
    floatParameter: 0
    intParameter: 0
    messageOptions: 0
  - time: 0.6666667
    functionName: Tetst
    data: 
    objectReferenceParameter: {fileID: 0}
    floatParameter: 0
    intParameter: 0
    messageOptions: 0

上記を追加することで、以下のように指定したフレームにAnimatioin Eventを追加することができます。

参考サイト

blog.unity.com

画像処理・クラスタリングを用いて画像内の色を単色化する

初めに

ある画像から近い色同士で色をまとめてほしい時があります。この際に使用できる画像処理やクラスタリングの手法を試してみました。

使用例として、ゲーム開発における地面のマテリアル(どのような地面の種別なのか)判定として使用できそうです

K-means法を使用したゲームアセットの地面画像の単色化デモ画像

上記は以下のUnity Assetを使用しています。わかりやすくするために地面より高いものは一部非表示にしています。

リポジトリは以下で公開しています

github.com

開発環境

アプローチの方針

今回は以下の手法を用いて画像の単色化を行っていきます。詳しい記事はリンクを貼っているので、そちらを参照してください

  1. k-means法
  2. DBSCAN法
  3. GMM法

ライブラリのインストール

pip install opencv-python-headless==4.9.0.80 numpy==1.26.4 matplotlib==3.9.0 scikit-learn==1.4.2

もしくはリポジトリを使用する場合は以下です

pip install -r requirements.txt

k-means法

今回はK-meansクラスタリングの回数が1回の場合は、求めていたものと違う結果になったため複数回行いその平均値を取るようにしました。また引数から単色する際の色の数を指定できます

以下で実行することができます

python K-means-clustering.py --image_path test.png --num_colors 5 --attempts 10

デモ

実行した際の結果は以下になります * num_colors(単色にする色の数) = 5 * attempts(試行回数) = 10

左が元の画像で、右側が処理後の画像です

コード

以下が実際のコードになります

import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

def apply_kmeans(image, k, attempts):
    # 画像を2次元配列に変換
    data = image.reshape((-1, 3))
    data = np.float32(data)

    # K-meansクラスタリングを複数回実行して最適な結果を選択
    best_labels = None
    best_centers = None
    best_inertia = float('inf')
    for _ in range(attempts):
        kmeans = KMeans(n_clusters=k, init='k-means++', n_init=1, max_iter=300)
        kmeans.fit(data)
        if kmeans.inertia_ < best_inertia:
            best_inertia = kmeans.inertia_
            best_labels = kmeans.labels_
            best_centers = kmeans.cluster_centers_

    return best_labels, best_centers

def main(image_path, k, attempts):
    # 画像の読み込み
    image = cv2.imread(image_path)

    # 画像をBGRからRGBに変換(matplotlibで表示するため)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # クラスタリングの適用
    labels, centers = apply_kmeans(image_rgb, k, attempts)

    # クラスタリングの結果を元の形に戻す
    centers = np.uint8(centers)
    segmented_image = centers[labels.flatten()]
    segmented_image = segmented_image.reshape(image_rgb.shape)

    # 元の画像の色を使用して単色化
    unique_labels = np.unique(labels)
    for label in unique_labels:
        mask = (labels == label).reshape(image_rgb.shape[:2])
        mean_color = np.mean(image_rgb[mask], axis=0)
        segmented_image[mask] = mean_color

    # 画像を表示
    plt.figure(figsize=(8, 4))  # ウィンドウのサイズを変更
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(image_rgb)
    plt.subplot(1, 2, 2)
    plt.title(f'Segmented Image with {k} Colors')
    plt.imshow(segmented_image)
    plt.tight_layout()  # レイアウトを自動調整
    plt.show()

    # 結果の保存
    output_path = 'segmented_image.png'
    cv2.imwrite(output_path, cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR))

    print(f'Segmented image saved to: {output_path}')

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="K-means clustering color quantization with parameter tuning")
    parser.add_argument('--image_path', type=str, default='test.png', help='Path to the input image')
    parser.add_argument('--num_colors', type=int, default=5, help='Number of colors for quantization')
    parser.add_argument('--attempts', type=int, default=10, help='Number of attempts for K-means clustering')

    args = parser.parse_args()

    main(args.image_path, args.num_colors, args.attempts)

DBSCAN法

処理をする中で画像に対して、データの2次元配列変換・全データポイント間の距離計算・一時的なデータ構造の作成をすることで、それなりのメモリを使用します。そのため、今回は画像のスケールを指定できるようにしています。

以下で実行することができます

python DBSCAN-clustering.py --image_path test.png --eps 10.0 --min_samples 10 --scale_factor 0.1

デモ

実行した際の結果は以下になります * scale_factor(スケール値) = 0.1

左が元の画像で、右側が処理後の画像です

コード

以下が実際のコードになります

import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import DBSCAN


def apply_dbscan(image, eps, min_samples):
    # 画像を2次元配列に変換
    data = image.reshape((-1, 3))
    data = np.float32(data)

    # DBSCANクラスタリングの適用
    db = DBSCAN(eps=eps, min_samples=min_samples).fit(data)
    labels = db.labels_

    # ノイズとして識別されたピクセルに対処
    unique_labels = np.unique(labels)
    centers = []
    for label in unique_labels:
        if label == -1:  # ノイズ
            centers.append([0, 0, 0])  # 黒に設定
        else:
            centers.append(np.mean(data[labels == label], axis=0))

    centers = np.uint8(centers)
    segmented_image = centers[labels]
    segmented_image = segmented_image.reshape(image.shape)

    return segmented_image

def main(image_path, eps, min_samples, scale_factor):
    # 画像の読み込み
    image = cv2.imread(image_path)

    # 画像のサイズを縮小
    height, width = image.shape[:2]
    new_height, new_width = int(height * scale_factor), int(width * scale_factor)
    image_resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)

    # 画像をBGRからRGBに変換(matplotlibで表示するため)
    image_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)

    # DBSCANクラスタリングの適用
    segmented_image = apply_dbscan(image_rgb, eps, min_samples)

    # 画像を表示
    plt.figure(figsize=(8, 4))  # ウィンドウのサイズを変更
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(cv2.cvtColor(cv2.resize(image, (new_width, new_height)), cv2.COLOR_BGR2RGB))
    plt.subplot(1, 2, 2)
    plt.title('Segmented Image using DBSCAN')
    plt.imshow(segmented_image)
    plt.tight_layout()  # レイアウトを自動調整
    plt.show()

    # 結果の保存
    output_path = 'segmented_image_dbscan.png'
    cv2.imwrite(output_path, cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR))

    print(f'Segmented image saved to: {output_path}')

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="DBSCAN clustering for image color quantization")
    parser.add_argument('--image_path', type=str, default='test.png', help='Path to the input image')
    parser.add_argument('--eps', type=float, default=10.0, help='The maximum distance between two samples for one to be considered as in the neighborhood of the other.')
    parser.add_argument('--min_samples', type=int, default=10, help='The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.')
    parser.add_argument('--scale_factor', type=float, default=0.1, help='Factor to scale the image down by.')

    args = parser.parse_args()

    main(args.image_path, args.eps, args.min_samples, args.scale_factor)

GMM法

あまり使用されない手法?らしいですが、念のため試してみます

デモ

実行した際の結果は以下になります * n_components = 5

左が元の画像で、右側が処理後の画像です

コード

以下が実際のコードになります

import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.mixture import GaussianMixture


def apply_gmm(image, n_components):
    # 画像を2次元配列に変換
    data = image.reshape((-1, 3))
    data = np.float32(data)

    # ガウシアン混合モデルの適用
    gmm = GaussianMixture(n_components=n_components).fit(data)
    labels = gmm.predict(data)
    centers = gmm.means_

    centers = np.uint8(centers)
    segmented_image = centers[labels]
    segmented_image = segmented_image.reshape(image.shape)

    return segmented_image

def main(image_path, n_components):
    # 画像の読み込み
    image = cv2.imread(image_path)

    # 画像をBGRからRGBに変換(matplotlibで表示するため)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # GMMクラスタリングの適用
    segmented_image = apply_gmm(image_rgb, n_components)

    # 画像を表示
    plt.figure(figsize=(8, 4))  # ウィンドウのサイズを変更
    plt.subplot(1, 2, 1)
    plt.title('Original Image')
    plt.imshow(image_rgb)
    plt.subplot(1, 2, 2)
    plt.title('Segmented Image using GMM')
    plt.imshow(segmented_image)
    plt.tight_layout()  # レイアウトを自動調整
    plt.show()

    # 結果の保存
    output_path = 'segmented_image_gmm.png'
    cv2.imwrite(output_path, cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR))

    print(f'Segmented image saved to: {output_path}')

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="GMM clustering for image color quantization")
    parser.add_argument('--image_path', type=str, default='test.png', help='Path to the input image')
    parser.add_argument('--n_components', type=int, default=5, help='Number of components for GMM.')

    args = parser.parse_args()

    main(args.image_path, args.n_components)

GCP-GPUでのCould not load library libcudnn_cnn_train.so.8.のエラー対応

初めに

AI周りの学習でtorchを使うことがありますが、cudannのエラーによって学習が始めらない問題にぶつかったので解決方法をメモしておきます

開発環境

  • GCP 

  • torch version : 2.3.0+cu121

  • cuda 12.1
  • Python 3.10
  • torch.backends.cudnn.version() : 8904
nvidia-smi
Sun May 12 08:37:42 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              54W / 400W |      4MiB / 40960MiB |     26%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

エラー詳細

/usr/local/cuda/lib64/libcudnn_cnn_train.so.8: undefined symbol: _ZN5cudnn3cnn34layerNormFwd_execute_internal_implERKNS_7backend11VariantPackEP11CUstream_stRNS0_18LayerNormFwdParamsERKNS1_20NormForwardOperationEmb, version libcudnn_cnn_infer.so.8

解決方法

ローカルのcudaのlibraryを削除します

cd /usr/local/cuda-12.1/lib64
sudo rm -f libcudnn*
cd /usr/local/cuda-12.1/include
sudo rm -f cudnn*

次にcudannのversionをbashに適応します

# cuda version change
export PATH=/usr/local/cuda-12.2/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-12.2/lib64:$LD_LIBRARY_PATH

source ~/.bashrc

最後に現在の状況を確認します

import torch
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.backends.cudnn.version())

これで以下のようになっていれば問題ないです

Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> print(torch.__version__)
2.3.0+cu121
>>> print(torch.cuda.is_available())
True
>>> print(torch.version.cuda)
12.1
>>> print(torch.backends.cudnn.version())
8902

参考記事

discuss.pytorch.org

指定したディレクトリ内のすべてのwavファイルのパスを再帰的に取得してtxtファイルに保存する

開発環境

詳細

以下のコードで指定したディレクトリ内のwavファイルのパスを一覧にしたテキストファイルが出力されます

# 指定したディレクトリ内のすべてのファイルのパスを再帰的に取得し、txtファイルに保存するスクリプト

import os
import sys

def get_file_paths(directory):
    file_paths = []

    for root, directories, files in os.walk(directory):
        for filename in files:
            if filename.endswith(".wav"):
                filepath = os.path.join(root, filename)
                file_paths.append(filepath)

    return file_paths

def save_file_paths_to_txt(file_paths, output_file):
    with open(output_file, 'w', encoding='utf-8') as file:
        for path in file_paths:
            file.write(path + '\n')

    print(f"File paths saved to {output_file}")

if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Please provide the directory path as a command-line argument.")
        sys.exit(1)

    directory = sys.argv[1]  # コマンドライン引数から検索対象のディレクトリを取得
    output_file = "file_paths.txt"  # 出力するテキストファイル名を指定

    file_paths = get_file_paths(directory)
    save_file_paths_to_txt(file_paths, output_file)

出力は以下のようになります

jvs\jvs001\falset10\wav24kHz16bit\jvs001_falset10_jvs001_falset10_BASIC5000_0235.wav
jvs\jvs001\falset10\wav24kHz16bit\jvs001_falset10_jvs001_falset10_BASIC5000_0408.wav
jvs\jvs001\falset10\wav24kHz16bit\jvs001_falset10_jvs001_falset10_BASIC5000_1140.wav
jvs\jvs001\falset10\wav24kHz16bit\jvs001_falset10_jvs001_falset10_BASIC5000_1356.wav

TransformersのOptimumを使ってモデルをonnxに変換する

開発環境

  • Ubutntu 22.02

準備

まずは以下のライブラリをインストールします

python -m pip install optimum

モデルの変換

例として、cyberagent/open-calm-smallを変してみます。 変換する際には、以下のコマンドで変換することができます。

optimum-cli export onnx --model cyberagent/open-calm-small open-calm-small.onnx --trust

ただモデルが対応していない場合は、カスタムモデルとして対応する必要があります (調査中)

huggingface.co

変換した後は、以下のように変換後のモデルが出力されます