EMNISTのデータを学習してONNXを出力する【Python】【ML】

はじめに

UnityのBarracudaでMLを使ったプロジェクトを作成したいので、準備としてデータを作成しています。 MLに関してはちょっと調べてる使っているレベルなので、コード等は間違っていることがあります。 また今回はChatGPTを試しに使ってプログラムを作成しているため、ご了承ください

やりたいこと

  • EMISTのデータをほかのプラットフォーム(今回はUnity)で使えるようにONNXとして出力する

環境

  • Docker version 20.10.22
  • Docker Desktop 4.16.1 (95567) : Windows 10

環境構築

  1. 今回は Docker上にPythonで学習用の環境を構築します。 昔はAnacondaを使用していましたが、環境を共有しやすいように dockerを採用しました。

前提として、DockerとDocker Wesktopがインストールされているものとします。 以下の二つのファイルを用意して、Dockerを起動します

  1. docker compose up -d --build
  2. docker compose exec python3 bash

Dockerfile

FROM python:3.10
USER root

RUN apt-get update
RUN apt-get -y install locales && \
    localedef -f UTF-8 -i ja_JP ja_JP.UTF-8
ENV LANG ja_JP.UTF-8
ENV LANGUAGE ja_JP:ja
ENV LC_ALL ja_JP.UTF-8
ENV TZ JST-9
ENV TERM xterm

RUN pip install --upgrade pip
RUN pip install --upgrade setuptools
RUN python -m pip install torchvision
RUN python -m pip install torch

docker-compose

version: '3'
services:
  python3:
    restart: always
    build: .
    container_name: 'python3'
    working_dir: '/root/'
    tty: true
    volumes:
      - ./opt:/root/opt

実装

MLに関してはよくわからないので、ChatGPTにお願いしていろいろ出してもらいました。 (ちゃんと使ったのは初めてなので、ChatGPTへの返答も記事のメモしてあります)

Dockerの起動が終わったら以下のファイルを opt/ に作成します。

作成したプログラムでデータセットから学習を行い、ONNXを作成します

python opt/xxx.py

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.onnx
import torch.nn.functional as F

# load EMNIST dataset
emnist_data = datasets.EMNIST(
    './EMNIST',
    # split='letters',
    split='balanced',
    train=True, download=True,
    transform=transforms.ToTensor())

data_loader = torch.utils.data.DataLoader(emnist_data, batch_size=2, shuffle=True)

print("EMNIST dataset loaded")


# create and train model
# create and train model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 47)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 128 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.NLLLoss()

print("Model created")

for epoch in range(10):
    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
#
# convert to ONNX format
print("Converting to ONNX format")
torch.onnx.export(model, torch.randn(1, 1, 28, 28), "model.onnx")

# split train and validation data
train_size = int(len(emnist_data) * 0.8)
val_size = len(emnist_data) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(emnist_data, [train_size, val_size])

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=True)

# Evaluate on validation set
val_loss = 0
correct = 0
with torch.no_grad():
    for data, target in val_loader:
        output = model(data)
        val_loss += criterion(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

val_loss /= len(val_dataset)
val_acc = correct / len(val_dataset)

print('Validation set: Average loss: {:.4f}, Accuracy: {:.2f}%'.format(val_loss, val_acc * 100))

結果

実行結果(処理の時間と精度)は以下です

また生成したONNXがdocker環境内になるので、ローカルPC側にコピーする際に以下で任意のパスにコピーしてきます

docker cp docker-container-id:file-path local-path

Unity側にONNXをコピーすれば、Unityで Barracudaなどを使って開発を行うことができます。

参考サイトおよびChatGPTの返答内容

データセット引用元

作成者: Gregory Cohen, Saeed Afshar, Jonathan Tapson, and André van Schaik タイトル: EMNIST: an extension of MNIST to handwritten letters 公開日: 2017年2月17日(初版)/2017年5月1日(第2版) URL: http://arxiv.org/abs/1702.05373

参考サイト

atmarkit.itmedia.co.jp

www.nist.gov

qiita.com

zenn.dev

ChatGPT

EMNISTを学習してONNXに出力するプログラムを作成するときにChatGPTに聞いた内容