畳み込みニューラルネットワーク

ディープラーニングのブームは画像解析において、ILSVRC のコンペティションで従来の解析手法よりもディープラーニングを用いたモデルが制度を大きく上回った頃から始まったといわれています。そして、それから現在まで 8 年ほど、画像処理、自然言語処理の領域において目覚ましい進展を遂げていることは事実から明らかです。

本章では、その画像解析において目覚ましい発展を遂げている畳み込みニューラルネットワーク (Convolutional Neural Network; 以下 CNN) の基本的な考えから実装方法まで学んでいきます。

本章の構成

  • フィルタから CNN へ
  • データの準備
  • 特徴抽出
  • 画像の分類問題

フィルタから CNN へ

前章から画像処理ではフィルタを設計し、必要な特徴を抽出するといった操作がメインであることが分かりました。それでは、ここで次の問題として、エッジ検出のフィルタの値はわかったが、犬と猫を判別するためのフィルタの値はいくらでしょうか。また、背景を除去することができるフィルタの値はいくらでしょうか。こう考えると、一見便利そうに見えるフィルタですが、そのフィルタの値を決定することができなければ使うことができず、このフィルタの値を決めることが難しいタスクであることがわかります。そのため、画像処理という学問が体系立ってできたのですが、ブレークスルーを起こすことはありませんでした。しかし、近年画像処理の領域でブレークスルーが起きていると紹介しましたが、なぜでしょうか。

26_1

従来の画像解析が上の流れだとすると、これから紹介する CNN では、この畳み込みと全結合 NN の働きを一体化させ、このプロセスをすべてニューラルモデルの中に包括したものといえます。これにより、これまで経験と勘によって行われていた前工程を自動化することができるようになりました。

26_2

さて、この章の本題なのですが、CNN では従来から使われてきた画像処理のフィルタから着想を得ており、人間と犬と猫を判別できるようなフィルタを経験と勘で求めることが難しいのであれば、これも一種のパラメータとして学習させれば良いと考えたのです。フィルタをかけて新しい画像を生成する処理を畳み込み (Convolution)、画像を縮小させる処理をプーリング (Pooling) と呼びます。新たに変換した画像を生成し、これを 1/2 など縮小していく処理を交互に繰り返していきます。一般に畳み込みの処理ではもともと 3ch3ch であった画像を 64ch,128ch,256ch,64ch, 128ch, 256ch, \cdots とチャネルを増やしていく。そして、画像を縮小していくと、徐々に全結合のニューラルネットワークに形が近づいていく様子が見て取れます。このように、大事な情報を抽出しながら徐々に小さくしていく。これであれば画像特有の上下左右の位置関係も比較的考慮することが可能です。

データの準備

それでは、実際に画像データを使って実装していきます。今回のデータセットは犬と猫ではなく、よく使われる MNIST という 0~9 までの手書き文字データセットを扱います。最初は難易度の低い問題から触っていき、徐々に慣れていきましょう。

画像に関するデータセットやよく使われる処理は torchvision ライブラリで用意されているため、こちらの使い方も覚えていきましょう。

26_3

本章の後半でモデルの定義をする際に、PyTorch Lightning を使うのですが Colab を使用する方は毎回インストールをする必要があります。他のライブラリとの関係を保つために Colab を使用する方は最初にインストールしておきます。

# インストール
!pip install pytorch-lightning
      
import torch, torchvision
      
torch.__version__, torchvision.__version__
      
('1.4.0', '0.5.0')

データセットの準備から行っていきます。PyTorch ではデータを torch.Tensor 型にして扱うことが決まりでした。

画像では torchvision.transforms.Compose の中で、データを読み込んだ後に行う処理を定義します。torch.Tensor 型に変更するためには、ToTensor() を使用します。この他にも画像を正規化したり、画像を水増ししたりする処理を加えることができるのですが、一度このまま進んでいきます。

# データ読み込み時の処理
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor()
])
      

torchvision.datasets にはデフォルトで用意されているデータセットがいくつかあります。そこから今回は MNIST を読み込んでいきます。

# 学習用データセットの読み込み
train = torchvision.datasets.MNIST(
    root='.', 
    train=True, 
    download=True, 
    transform=transform)
      

引数がいくつかありますので、説明します。

  • root:ダウンロードするディレクトリを決定します。'.' とすればカレントディレクトリ(現在のディレクトリ)を示します
  • train:学習用データセットを含むかどうか(False とすれば、テスト用データセットのみが含まれます)
  • download:既に一度ダウンロードしていれば False、はじめてであれば True としてダウンロードを実行します
  • transform:データセットの読み込み時に行う処理を指定します

データセットの中身を確認していきましょう。

# データセットの中身を確認
train
      
Dataset MNIST Number of datapoints: 60000 Root location: . Split: Train StandardTransform Transform: Compose( ToTensor() )
# サンプル数
len(train)
      
60000
# 入力値と目標値がタプルで格納
type(train[0])
      
tuple

タプルで 60,000 個のデータが格納されているため、要素番号を指定することで入力値と目標値をそれぞれ確認することができます。

# 入力値
train[0][0]
      
tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706, 0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922, 0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137, 0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000, 0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275, 0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922, 0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294, 0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627, 0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098, 0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922, 0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922, 0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706, 0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922, 0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922, 0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])
# 目標値
train[0][1]
      
5

入力が今回は 1 枚の画像データとなっており、それぞれの値は輝度値を表しています。目標値を確認すると 5 となっているので、1 枚目のデータは手書き数字で 5 が書かれているようです。

また、入力値のサイズに注目してみましょう。

# 入力値のサイズ
train[0][0].shape
      
torch.Size([1, 28, 28])

画像が (height, width, channels) の順ではなく、(channels, height, width) の順で格納されていることが分かります。CNN を扱う際には、(height, width, channels) の順に並べていることが多いのですが、torchvision を使うとこういった処理も自動的に内部でよしなに行ってくれるので、実装上は特に気にしなくても大丈夫です。

計算に行く前に試しに 1 枚目をプロットしてみましょう。データの中身を確認することは当然ですが大切です。Matplotlib で可視化するには、(height, width, channels) の順番に変更する必要があります。np.transpose を使って変更しましょう。

import numpy as np
import matplotlib.pyplot as plt
      
# (0:channels, 1:height, 2:width) -> (1:height, 2:width、0:channels)
img = np.transpose(train[0][0], (1, 2, 0))
      
img.shape
      
torch.Size([28, 28, 1])

Matplotlib でグレースケール表示するには、チャネルサイズをなくす必要があるのでさらに (height, width) に変形します。

img = img.reshape(img.shape[0], img.shape[1])
      
img.shape
      
torch.Size([28, 28])
plt.imshow(img, cmap='gray');
      
<Figure size 432x288 with 1 Axes>

こちらが正解ラベル 5 の画像です。このような画像が 60,000 枚用意されているのが MNIST です。

特徴抽出

今回はモデルを組むことを目標とするのではなく、畳み込み、プーリング、全結合層への一連処理の流れ convolution -> pooling -> fc を確認していきます。

テスト用の画像として、先程プロットした MNIST の画像 1 枚目を使用します。

x = train[0][0]
      
x.shape
      
torch.Size([1, 28, 28])

サイズを確認すると、PyTorch で扱うことのできる (channels, height, width) になっているため、そのまま使用します。もしも、今後皆さんがご自身のデータセットを活用する際には、データセットの形式も意識する必要があることを覚えておいてください。

PyTorch では torch.nn 内に Conv2d が準備されています。引数の内容は、以下の通りです。

  • in_channels:入力の channel の数
  • out_channels:出力の channel の数
  • kernel_size:フィルタ(カーネル)のサイズ
  • stride:フィルタを動かす幅 (default=1)
  • padding:画像の外側を囲う数 (padding=0)

それでは実装していきましょう。

import torch.nn as nn
import torch.nn.functional as F
      
# 畳み込み層の定義
conv = nn.Conv2d(
    in_channels=1, 
    out_channels=4,
    kernel_size=3,
    stride=1,
    padding=1)
      

宣言した時点で、全結合層と同じようにフィルタの重みがランダムに割り振られています。中身を確認してみましょう。

conv.weight
      
Parameter containing: tensor([[[[ 0.1453, 0.2541, -0.2919], [-0.1666, 0.1382, -0.0588], [-0.2303, 0.0689, 0.1214]]], [[[ 0.0753, -0.2702, -0.2139], [ 0.1549, 0.0478, 0.0988], [-0.2044, 0.1995, 0.3076]]], [[[ 0.0849, 0.2939, 0.2317], [-0.1568, 0.0250, -0.1974], [-0.1257, -0.2839, 0.3040]]], [[[-0.0348, 0.0871, -0.0512], [ 0.2326, -0.3136, -0.3081], [ 0.1346, 0.1522, -0.0883]]]], requires_grad=True)
conv.weight.shape
      
torch.Size([4, 1, 3, 3])
conv.bias
      
Parameter containing: tensor([ 0.1906, -0.0142, -0.2758, 0.0444], requires_grad=True)
conv.bias.shape
      
torch.Size([4])

フィルタの重みで確認できるサイズは、(出力画像のchannel, 入力画像のchannel, kernel_height, kernel_width) となっています。つまりこちらでは、(1, 3, 3) のフィルタを出力画像数の 4 セット生成されていることになります。

今回はグレースケールの画像に対して畳み込みを行いましたが、もしも channels=3 のカラー画像に対してフィルタサイズ 3 の畳み込みを行う場合は (3, 3, 3) のフィルタが 4 セット生成されることになります。

それでは、定義した conv を入力 x に対して適用しましょう。通常はミニバッチ単位で入力がはいってくるため、サイズは (channels, height, width) ではなく、(batchsize, channels, height, width) となります。今回は batchsize=1 です。

# batchsize=1 となるようにリサイズ
x = x.reshape(1, 1, 28, 28)
      
# 畳み込み
x = conv(x)
x
      
tensor([[[[ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906], [ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906], [ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906], ..., [ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906], [ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906], [ 0.1906, 0.1906, 0.1906, ..., 0.1906, 0.1906, 0.1906]], [[-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142], [-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142], [-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142], ..., [-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142], [-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142], [-0.0142, -0.0142, -0.0142, ..., -0.0142, -0.0142, -0.0142]], [[-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758], [-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758], [-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758], ..., [-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758], [-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758], [-0.2758, -0.2758, -0.2758, ..., -0.2758, -0.2758, -0.2758]], [[ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444], [ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444], [ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444], ..., [ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444], [ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444], [ 0.0444, 0.0444, 0.0444, ..., 0.0444, 0.0444, 0.0444]]]], grad_fn=<MkldnnConvolutionBackward>)
x.shape
      
torch.Size([1, 4, 28, 28])

(batchsize, channels, height, width) なので、channels が 4 に増えていることがわかります。

次に、プーリング処理 (Pooling) を実装してみましょう。プーリングは torch.nn.functional に 用意されており、今回は max_pool2d を使用します。カーネルサイズが (2, 2) のプーリングは画像サイズが半分になります。

# プーリング処理
x = F.max_pool2d(x, kernel_size=2, stride=2)
      
# サイズを確認
x.shape
      
torch.Size([1, 4, 14, 14])

(1, 4, 28, 28)(1, 4, 14, 14) と画像サイズが縮小されていることが分かります。このようにして、画像の特徴を抽出できました。実際にはこの後にも複数の畳み込み演算、プーリング処理を繰り返すことになるのですが、そのまま全結合層と結合して、分類にはいっていきます。

全結合層と結合

畳み込み層で取得した値を全結合層に入力していきます。ただ、ここで問題となるのが、データのサイズです。上記で求めた値は、1 サンプルにつき (4, 14, 14) のテンソルで定義されていますが、全結合層に入力するときは、ベクトルでなければなりません。そこで、「テンソル → ベクトル」 に変換する Flatten と呼ばれる処理が必要になります。

サイズは (4, 14, 14) なので、ベクトル化するには、4×14×14 の値を求めます。

print('channels :', x.shape[1])
print('height :', x.shape[2])
print('width :', x.shape[3])
      
channels : 4 height : 14 width : 14
x_shape = x.shape[1] * x.shape[2] * x.shape[3]
x_shape
      
784

こちらの数値をもとに、(サンプル数, ベクトルの要素数) のサイズに変更すれば、全結合層に入力することができます。サイズ変更をする際には torch.view() 関数を使います。

# 今回はベクトルの要素数が決まっているため、サンプル数は自動で設定
# -1 とするともう片方の要素に合わせて自動的に設定されます
x = x.view(-1, x_shape)
      
x.shape
      
torch.Size([1, 784])

今回は 0~9 までの手書き数字を分類したいので 10 クラス分類です。ノード数を 10 としましょう。

# 全結合層の定義
fc = nn.Linear(x_shape, 10) # 784 => 10
      
# 線形変換
x = fc(x)
      
x
      
tensor([[-0.0336, 0.0582, 0.0231, 0.0931, 0.1196, -0.1519, 0.0765, -0.0218, -0.1022, -0.2128]], grad_fn=<AddmmBackward>)
x.shape
      
torch.Size([1, 10])

このように、畳み込み演算から全結合層での線形変換までの流れを確認できました。

それでは、PyTorch Lightning を使って実際にモデルの定義から学習までを行いましょう。

画像の分類問題

PyTorch Lightning を使って、CNN のネットワークを定義していきましょう。ここまでの流れを一気通貫しておこなっていくだけです。

データセットは引き続き MNIST を使用し、0~9 までの 10 種類の手書き文字を分類する問題に取り組みましょう。

データセットの準備

復習として、もう一度データセットの準備からしていきます。

# データ読み込み時に行う処理
transform = transforms.Compose([
    transforms.ToTensor()
])
      

先程は学習用データセットのみを用意したのですが、今回はテスト用データセットも準備します。train=False とすれば、テスト用データセットのみが指定できます。

# データセットの取得
train_val = torchvision.datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform)

test = torchvision.datasets.MNIST(
    root='.',
    train=False,
    download=True,
    transform=transform)
      
# train : val = 80% : 20%
n_train = int(len(train_val) * 0.8)
n_val = len(train_val) - n_train
      
# データをランダムに分割
torch.manual_seed(0)

train, val = torch.utils.data.random_split(train_val, [n_train, n_val])
      
# 分割後のサンプル数を確認
len(train), len(val), len(test)
      
(48000, 12000, 10000)

モデルの定義

シンプルなネットワークを定義していきましょう。流れとしては convolution -> pooling -> fc のネットワークにします。

また PyTorch Lightning を使用し、前回同様に TrainNet, ValidationNet, TestNet として分けてクラスを定義します。

import pytorch_lightning as pl
from pytorch_lightning import Trainer
      
class TrainNet(pl.LightningModule):

    @pl.data_loader
    def train_dataloader(self):
        return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True)

    def training_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        results = {'loss': loss}
        return results
      
class ValidationNet(pl.LightningModule):

    @pl.data_loader
    def val_dataloader(self):
        return torch.utils.data.DataLoader(val, self.batch_size)

    def validation_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'val_loss': loss, 'val_acc': acc}
        return results

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        results =  {'val_loss': avg_loss, 'val_acc': avg_acc}
        return results
      
class TestNet(pl.LightningModule):

    @pl.data_loader
    def test_dataloader(self):
        return torch.utils.data.DataLoader(test, self.batch_size)

    def test_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'test_loss': loss, 'test_acc': acc}
        return results

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        results = {'test_loss': avg_loss, 'test_acc': avg_acc}
        return results
      
class Net(TrainNet, ValidationNet, TestNet):

    def __init__(self, input_size=784, hidden_size=100, output_size=10, batch_size=256):
        super(Net, self).__init__()
        self.batch_size = batch_size
        # 使用する層の宣言
        self.conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def lossfun(self, y, t):
        return F.cross_entropy(y, t)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, x):
        x = self.conv(x)
        x = F.max_pool2d(x, 2, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
      

モデルの学習

今回の学習から GPU を使用していきます。Trainer のインスタンス化の際に、gpus の引数に使用する GPU の数を指定するだけで GPU での演算に切り替えることが可能です。また、GPU を使用する場合には乱数のシードを固定するときにも注意が必要です。これまでと同様に、torch.manual_seed により、CPU と CUDA に対するシードを固定することができますが、cuDNN のシードを固定することができません。そこで、以下のようにシードを固定します。

# cuDNN に対する再現性の確保
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
      

決定的な演算方法にするために上の 2 行を追加しました。計算速度が遅くなることもありますが、再現性の確保のためにここは設定しておきましょう。

# 乱数のシードを固定
torch.manual_seed(0)

# モデルの学習準備
net = Net()

# 単一のGPUで学習
trainer = Trainer(gpus=1)

# モデルの学習
trainer.fit(net)
      
# 検証&テストデータに対する結果
trainer.test()
trainer.callback_metrics
      
Testing: 100%|██████████| 40/40 [00:01<00:00, 33.90batch/s]
{'loss': 0.10484535992145538, 'test_acc': 0.9750000238418579, 'test_loss': 0.084235779941082, 'val_acc': 0.9708040356636047, 'val_loss': 0.09426042437553406}

検証データとテストデータに対しても約 97% の正解率で分類できるモデルを構築できました。古典的な画像処理のフィルタを抑えておくことで CNN も簡単に理解することができます。

学習済みモデルの重みを保存

学習が終わると、学習済みモデルが得られます。学習を行うことの目標は、上手くデータのパターンを見つけて、未知のデータに対しても良い予測精度を出すことです。それが推論と呼ばれます。

PyTorch で準備されている torch.save の関数を使用すれば学習済みモデルの重みを保存できます。保存の際には、torch.save(モデルのパラメータ, モデルの名前) としてモデルの名前と学習済みモデルを指定しましょう。モデルのパラメータは モデル.state_dict() とすると取得できます。

# パラメータの名前
net.state_dict().keys()
      
odict_keys(['conv.weight', 'conv.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

.keys() を外せば、パラメータの値を一覧表示することができます。それでは保存していきましょう。

# 学習済みモデルの保存
torch.save(net.state_dict(), 'mnist.pt')
      

上のセルを実行後に mnist.pt というファイルができていれば保存完了です。今後は、皆さんが管理しやすい名前で登録してください。

学習済みモデルの重みを使用した推論

先程保存したファイルにはモデルの構造の情報がありません。ですので、学習済みモデルは単にファイルをロードするだけでなく、まずはモデルの構造を明示しておき、そのモデルに対して重みの値を当てはめながらロードしていくことになります。

# モデルの定義
net = Net()
      
# 重みの読み込み
net.load_state_dict(torch.load('mnist.pt'))
      
<All keys matched successfully>

net.load_state_dict() でロードすることを明記し、torch.load() でロードするファイルを指定します。重みの数が合っていなければ失敗してしまいますが、成功すれば <All keys matched successfully> と出力されます。

予測値の計算

本来であれば、学習に使用したデータとは別のデータを使用して予測値を計算していくのですが、試しなのでロードしたモデルが実際に動くかどうかを目標にします。一番最初のサンプルに対する予測値を計算してみましょう。

# 一番最初のサンプル
train[0][0].shape
      
torch.Size([1, 28, 28])

モデルで計算する際には (batchsize, channels, height, width) にする必要があるのでした。次元を増やす方法はいくつかあるのですが、torch.unsqueeze() を使用します。引数に次元を増やしたいインデックスを指定します。今回であれば一番最初に次元を増やしたいので、0 とします。

sample = train[0][0].unsqueeze(0)
sample.shape
      
torch.Size([1, 1, 28, 28])
# 予測値の計算
y_predict = net.forward(sample)
y_predict
      
tensor([[-5.0208, -4.7895, -3.1545, 2.7309, 2.4311, 2.2265, -5.5931, 3.1665, 2.9033, 10.9998]], grad_fn=<AddmmBackward>)

計算できました。学習時であればソフトマックス関数を使って、0~1 の値に変換し損失を計算するのですが、予測ラベルが何かどうかのみが知りたいので torch.argmax() を使って、最も大きな値のインデックスを取り出します。

# 予測ラベル
y_predict.argmax()
      
tensor(9)

予測が正解しているか確認してみましょう。

# 目標値
train[0][1]
      
9

正解していました!

このように学習済みモデルを使用した推論まで実行することができました。基本的にはこの流れをさらに複雑なネットワークを使用したり、データセットを加工したりする処理が入ってくることになるのですが、本章の基本を抑えておけば今後の難解なモノにも柔軟に対応できます。

shareアイコン