畳み込みニューラルネットワークの精度向上

本章では、モデルの精度向上に役立つチューニングのテクニックをいくつか紹介していきます。前章までは基礎編として使い方を重視して説明してきました。使い方を理解した次のステップとして、本番環境でも使用できるレベルの予測精度にどのように近づけることができるのか学んでいきましょう。

よく使われるテクニックをいくつか実際に実装しながら、適用前と適用後の予測精度を比較します。

本章の構成

  • ベースモデルの作成
  • 最適化アルゴリズム
  • 過学習対策
  • 活性化関数

ベースモデルを作成

はじめに、ベースモデルを作成しましょう。今後色々なテクニックを適用する際に、適用前と適用後の差分を正確に測るためです。それぞれのテクニックを適用する場合には、追加部分以外はベースモデルと同じモデル構造にします。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
      
TensorFlow 2.x selected. 2.1.0

データセットの準備

前章で扱った手書き文字である MNIST の分類は簡単なモデルの定義でもある程度の正解率が得られますが、もう少し難しい問題設定で試してみましょう。CIFAR10 と呼ばれる以下のような 10 クラスの分類を行います。CIFAR10 は MNIST のグレースケール画像とは異なり、フルカラー画像です。CIFAR10 も MNIST と同様に、TensorFlow の datasets にデータセットが用意されています。

35_1

(x_train, t_train), (x_test, t_test) = tf.keras.datasets.cifar10.load_data()
      
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 170500096/170498071 [==============================] - 6s 0us/step

それでは、今回扱うデータを 25 枚ランダムに抜粋して表示します。

正解ラベル 種別
0 airplane
1 automobile
2 bird
3 cat
4 deer
5 dog
6 frog
7 horse
8 ship
9 truck

10 クラス分類となっており、上記の表の種別を分類することが目標です。32×3232 \times 32 と低解像度なところも、CIFAR10 がよく画像の練習問題として扱われる理由のひとつです。

plt.figure(figsize=(12,12))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.imshow(x_train[i])
      
<Figure size 864x864 with 25 Axes>
# 正規化
x_train = x_train / 255.0
x_test = x_test / 255.0
      
x_train.shape, x_test.shape, t_train.shape, t_test.shape
      
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

モデルの定義と学習

import os
import random

def reset_seed(seed=0):

    os.environ['PYTHONHASHSEED'] = '0'
    random.seed(seed) # random関数のシードを固定
    np.random.seed(seed) #numpyのシードを固定
    tf.random.set_seed(seed) #tensorflowのシードを固定
      
     from tensorflow.keras import models, layers
      
# シードの固定
reset_seed(0)

# モデルの構築
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
]) 

# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=1e-3)

# モデルのコンパイル
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])

model.summary()
      
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 32, 32, 32) 896 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 16, 16, 32) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 16, 16, 64) 18496 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 8, 8, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 8, 8, 128) 73856 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 4, 4, 128) 0 _________________________________________________________________ flatten (Flatten) (None, 2048) 0 _________________________________________________________________ dense (Dense) (None, 128) 262272 _________________________________________________________________ dense_1 (Dense) (None, 10) 1290 ================================================================= Total params: 356,810 Trainable params: 356,810 Non-trainable params: 0 _________________________________________________________________
# 学習の詳細設定
batch_size = 1024
epochs = 50

# 学習の実行
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test))
      
Train on 50000 samples, validate on 10000 samples Epoch 1/50 50000/50000 [==============================] - 11s 216us/sample - loss: 1.9281 - accuracy: 0.3015 - val_loss: 1.6494 - val_accuracy: 0.4103 Epoch 2/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.5383 - accuracy: 0.4497 - val_loss: 1.4247 - val_accuracy: 0.4882 Epoch 3/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.3910 - accuracy: 0.5025 - val_loss: 1.3631 - val_accuracy: 0.5176 Epoch 4/50 50000/50000 [==============================] - 2s 43us/sample - loss: 1.2837 - accuracy: 0.5468 - val_loss: 1.2566 - val_accuracy: 0.5512 Epoch 5/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.1994 - accuracy: 0.5785 - val_loss: 1.1757 - val_accuracy: 0.5865 Epoch 6/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.1309 - accuracy: 0.6058 - val_loss: 1.1050 - val_accuracy: 0.6091 Epoch 7/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.0634 - accuracy: 0.6313 - val_loss: 1.0750 - val_accuracy: 0.6309 Epoch 8/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.0310 - accuracy: 0.6423 - val_loss: 1.0350 - val_accuracy: 0.6380 Epoch 9/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.9716 - accuracy: 0.6641 - val_loss: 1.0366 - val_accuracy: 0.6440 Epoch 10/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.9482 - accuracy: 0.6734 - val_loss: 0.9739 - val_accuracy: 0.6636 Epoch 11/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9145 - accuracy: 0.6839 - val_loss: 0.9492 - val_accuracy: 0.6770 Epoch 12/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.8768 - accuracy: 0.6990 - val_loss: 0.9495 - val_accuracy: 0.6738 Epoch 13/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.8727 - accuracy: 0.7002 - val_loss: 0.9530 - val_accuracy: 0.6742 Epoch 14/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.8339 - accuracy: 0.7140 - val_loss: 0.9191 - val_accuracy: 0.6796 Epoch 15/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.8073 - accuracy: 0.7224 - val_loss: 0.8858 - val_accuracy: 0.6960 Epoch 16/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7845 - accuracy: 0.7311 - val_loss: 0.8875 - val_accuracy: 0.6934 Epoch 17/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7600 - accuracy: 0.7409 - val_loss: 0.9158 - val_accuracy: 0.6863 Epoch 18/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7558 - accuracy: 0.7422 - val_loss: 0.9166 - val_accuracy: 0.6898 Epoch 19/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7434 - accuracy: 0.7464 - val_loss: 0.8764 - val_accuracy: 0.6988 Epoch 20/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7123 - accuracy: 0.7562 - val_loss: 0.8426 - val_accuracy: 0.7116 Epoch 21/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.6869 - accuracy: 0.7645 - val_loss: 0.8359 - val_accuracy: 0.7132 Epoch 22/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.6705 - accuracy: 0.7725 - val_loss: 0.8462 - val_accuracy: 0.7078 Epoch 23/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6575 - accuracy: 0.7747 - val_loss: 0.8244 - val_accuracy: 0.7160 Epoch 24/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.6343 - accuracy: 0.7834 - val_loss: 0.8377 - val_accuracy: 0.7119 Epoch 25/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6211 - accuracy: 0.7884 - val_loss: 0.8430 - val_accuracy: 0.7145 Epoch 26/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6017 - accuracy: 0.7948 - val_loss: 0.8247 - val_accuracy: 0.7212 Epoch 27/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.5866 - accuracy: 0.8002 - val_loss: 0.8424 - val_accuracy: 0.7198 Epoch 28/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5759 - accuracy: 0.8025 - val_loss: 0.8393 - val_accuracy: 0.7184 Epoch 29/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5618 - accuracy: 0.8076 - val_loss: 0.8364 - val_accuracy: 0.7216 Epoch 30/50 50000/50000 [==============================] - 2s 47us/sample - loss: 0.5533 - accuracy: 0.8104 - val_loss: 0.8419 - val_accuracy: 0.7214 Epoch 31/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5229 - accuracy: 0.8216 - val_loss: 0.8396 - val_accuracy: 0.7280 Epoch 32/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5011 - accuracy: 0.8307 - val_loss: 0.8325 - val_accuracy: 0.7273 Epoch 33/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5007 - accuracy: 0.8292 - val_loss: 0.8373 - val_accuracy: 0.7254 Epoch 34/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4830 - accuracy: 0.8358 - val_loss: 0.8491 - val_accuracy: 0.7297 Epoch 35/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.4749 - accuracy: 0.8384 - val_loss: 0.8430 - val_accuracy: 0.7268 Epoch 36/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4576 - accuracy: 0.8452 - val_loss: 0.8578 - val_accuracy: 0.7266 Epoch 37/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4393 - accuracy: 0.8524 - val_loss: 0.8503 - val_accuracy: 0.7318 Epoch 38/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4316 - accuracy: 0.8523 - val_loss: 0.8814 - val_accuracy: 0.7242 Epoch 39/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4199 - accuracy: 0.8581 - val_loss: 0.8683 - val_accuracy: 0.7286 Epoch 40/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.4064 - accuracy: 0.8637 - val_loss: 0.9345 - val_accuracy: 0.7116 Epoch 41/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.3811 - accuracy: 0.8730 - val_loss: 0.9007 - val_accuracy: 0.7256 Epoch 42/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.3697 - accuracy: 0.8753 - val_loss: 0.8940 - val_accuracy: 0.7302 Epoch 43/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.3602 - accuracy: 0.8806 - val_loss: 0.9283 - val_accuracy: 0.7190 Epoch 44/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.3472 - accuracy: 0.8833 - val_loss: 0.9180 - val_accuracy: 0.7229 Epoch 45/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.3282 - accuracy: 0.8914 - val_loss: 0.9188 - val_accuracy: 0.7310 Epoch 46/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.3151 - accuracy: 0.8957 - val_loss: 0.9375 - val_accuracy: 0.7225 Epoch 47/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.3076 - accuracy: 0.8979 - val_loss: 0.9473 - val_accuracy: 0.7260 Epoch 48/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.2949 - accuracy: 0.9024 - val_loss: 0.9712 - val_accuracy: 0.7254 Epoch 49/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.2923 - accuracy: 0.9029 - val_loss: 0.9823 - val_accuracy: 0.7306 Epoch 50/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.2739 - accuracy: 0.9113 - val_loss: 1.0067 - val_accuracy: 0.7262

結果の確認

results = pd.DataFrame(history.history)
      
results[['accuracy', 'val_accuracy']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe304211d0>
<Figure size 432x288 with 1 Axes>
results[["loss", "val_loss"]].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe301abf28>
<Figure size 432x288 with 1 Axes>
results.tail(1)
      
loss accuracy val_loss val_accuracy
49 0.273878 0.9113 1.006714 0.7262
Train Test
Base Accuracy 0.911 0.726
Base Loss 0.274 1.007

上記のスコアをベースラインとして、様々なテクニックを適用してどう変化するか確認していきましょう。

最適化アルゴリズム

Basic of Deep Learning の章で、勾配降下法を紹介しました。勾配降下法をベースとして SGD などのさまざまな最適化アルゴリズムがあります。どのアルゴリズムが一番良いのかについてはその時のデータやモデルによって様々であり、色々と試していく必要があるのが現状で、これを使えば間違いないというアルゴリズムはまだ存在しません。今でも活発に世界中のトップリサーチャーによって研究されています。

その中でも現在、代表的な最適化アルゴリズムとしては下記が挙げられます。

  • SGD
  • Momentum SGD
  • RMSprop
  • Adam

他にも存在しますが、まずはこちらを把握しておけば問題ありません。最適化アルゴリズムを使った精度向上のポイントとしては、大きく 2 点あります。

  • 最適化アルゴリズムの選択
  • ハイパーパラメータのチューニング

まずはどの最適化アルゴリズムを選択するかですが、こちらは決まりはありません。Momentum SGD と Adam が初手でよく使われています。

また、各アルゴリズムのハイパーパラメータチューニングも重要なポイントです。一概にどの値にすれば良いのかというのは難しいですが、学習率 (Learning Rate) を 1e-11e-5 辺りの値にすることが経験上多いです。手動でチューニングすることも多いですが、Optuna などの探索ツールを使用して、効率的に探していくこともありますのでどちらも頭の中の候補にいれておきましょう。

それでは、上記で紹介した 4 つの最適化アルゴリズムについて、数式を交えて簡単にご紹介します。数式は覚える必要ありませんので、気になる方だけご覧ください。ここでは勾配を \nabla(ナブラ)という記号で表現します。

SGD (Stochastic Gradient Descent: 確率的勾配降下法)

SGD (Stochastic Gradient Descent: 確率的勾配降下法) は、勾配降下法をミニバッチ学習(オンライン学習)で行ったものです。勾配降下法では、局所最適解への収束が起こりやすいという問題点をランダムにサンプルを選び出すことで解消しました。勾配を求め、重み w\mathbf{w} を勾配と逆方向に更新します。更新する際の更新幅の調整のために、学習係数 η\eta を設けています。

wwηL\begin{array}{c} \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla \mathcal L \end{array}

TensorFlow では、tf.keras.optimizers.SGD(lr=1e-2) として設定できます。lr が学習係数です。

Momentum SGD

vαvηLww+v\begin{aligned} v &\leftarrow \alpha v - \eta \nabla \mathcal L \\ \mathbf{w} &\leftarrow \mathbf{w} + v \end{aligned}

Momentum SGD は SGD に慣性項(vv)を付け足したものです。1 式を 2 式に代入すると wwηL+αv\mathbf{w} \leftarrow \mathbf{w} - \eta \nabla \mathcal L + \alpha v となり、前半部分は SGD と同様であることがわかります。

また、ハイパーパラメータとして α\alpha が追加されており、実装では momentum 引数として指定します。TensorFlow では、tf.keras.optimizers.SGD(lr=1e-2, momentum=0.9) のように設定できます。

このアルゴリズムはハイパーパラメータが 2 つに増えており、最適化が難しいという問題があります。しかし、初手で使うことが多い最適化アルゴリズムなので、皆さんも迷われたらこちらを使ってみてはいかがでしょうか。

RMSprop

rαr+(1α)LLΔwηr+ϵLww+Δw\begin{aligned} r &\leftarrow \alpha r + (1 - \alpha) \nabla \mathcal L \odot \nabla \mathcal L \\\\ \Delta \mathbf{w} &\leftarrow - \frac{\eta}{\sqrt r + \epsilon} \odot \nabla \mathcal L \\\\ \mathbf{w} &\leftarrow \mathbf{w} + \Delta \mathbf{w} \end{aligned}

RMSprop は、SGD における学習係数の箇所を学習の収束に合わせて(勾配の大きさに合わせて)変化するように組まれたアルゴリズムです。勾配の 2 乗の指数移動平均を取るように設計されています。

2 式を 3 式に代入すると、

wwηr+ϵL\begin{array}{c} \mathbf{w} \leftarrow \mathbf{w} - \frac{\eta}{\sqrt{r} + \epsilon} \odot \nabla \mathcal L \end{array}

となり、SGDでの学習係数 η\eta の箇所が ηr+ϵ\frac{\eta}{\sqrt{r} + \epsilon} に変わっていることがわかります。

TensorFlow では tf.keras.optimizers.RMSprop(lr=1e-3, rho=0.9, epsilon=1e-8) として設定でき、それぞれの引数は

  • η\eta : lr
  • α\alpha : rho
  • ϵ\epsilon : eps

に対応しています。こちらもハイパーパラメータの数が増え、最適化が困難なアルゴリズムの一つです。

Adam (Adaptive moment estimation)

t=t1sβ1s+(1β1)Lrβ2r+(1β2)LLs^s1β1tr^r1β2tΔwηr^+ϵs^ww+Δw\begin{aligned} t &= t - 1 \\\\ s &\leftarrow \beta_{1} s + ( 1 - \beta_{1} ) \nabla \mathcal L \\\\ r &\leftarrow \beta_{2} r + ( 1 - \beta_{2}) \nabla \mathcal L \odot \nabla \mathcal L \\\\ \hat{s} &\leftarrow \frac{s}{1 - \beta_{1}^t} \\\\ \hat{r} &\leftarrow \frac{r}{1 - \beta_{2}^t} \\\\ \Delta \mathbf{w} &\leftarrow - \frac{\eta}{\sqrt{\hat{r}} + \epsilon} \hat{s} \\\\ \mathbf{w} &\leftarrow \mathbf{w} + \Delta \mathbf{w} \end{aligned}

Adam は 前述の Momentum の慣性的な動きと、RMSprop の適応的に学習係数を調整する考えを組み合わせたアルゴリズムです。現在、最も評価されているアルゴリズムのひとつですが Adam はハイパーパラメータの数が非常に多いので、それぞれのハイパーパラメータを適切にチューニングすることがポイントになります。

TensorFlow では、 tf.keras.optimizers.Adam(lr=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8) として設定でき、それぞれの引数は

  • η\eta : lr
  • β1\beta_{1} : beta_1
  • β2\beta_{2} : beta_2
  • ϵ\epsilon : epsilon

に対応しています。デフォルトでも良い結果が出やすいアルゴリズムですが、Optuna 等を使用してチューニングすることもおすすめします。

過学習対策

機械学習をおこなう上で、よく遭遇する現象として過学習があります。学習データに対して過度に適合してしまい本来の目的である未知のデータに対する誤差(汎化誤差)が大きくなってしまうことを指します。そのような過学習に対する対策として、いくつかの方法をご紹介していきます。

過学習を防止するための最良の解決策は、より多くの学習用データを使うことです。多くのデータで学習を行えば行うほど、モデルは自然により汎化していく様になります。これが不可能な場合、次善の策は本節で紹介するようなテクニックを使うことです。

過学習への対策として、本章では 4 つのテクニックをご紹介します。

  • ドロップアウト (Dropout)
  • 正則化(Regularization)
  • 早期終了 (Early Stopping)
  • バッチノーマリゼーション (Batch Normalization)

ドロップアウト

35_2

出典:http://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf

ドロップアウトでは、ニューラルネットワークを学習する際に、ある更新で層の中のノードのうちのいくつかを無効にして学習を行い、次の更新では別のノードを無効にして学習をおこなうことを繰り返していきます。これにより学習時にネットワークの自由度を強制的に小さくして汎化性能をあげ、過学習を避けることができます。

ドロップアウトが高性能である理由は、アンサンブル学習を近似しているからと言われています。アンサンブル学習とは、複数のモデルに学習させて、予測結果を統合することで汎化性能を高める手法です。

tensorflow.keras.layers.Dropout で用意されており、引数には入力ユニットをドロップする割合を指定します。もし、Dropout(0.5) とするならば半分のノードを無効化して学習するということを意味しています。

# シードの固定
reset_seed(0)

# モデルのインスタンス化
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=1e-3)

# モデルのコンパイル
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])
      
# 学習の詳細設定
batch_size = 1024
epochs = 50

# 学習の実行
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test))
      
Train on 50000 samples, validate on 10000 samples Epoch 1/50 50000/50000 [==============================] - 3s 55us/sample - loss: 1.9445 - accuracy: 0.2882 - val_loss: 1.6409 - val_accuracy: 0.4091 Epoch 2/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.5640 - accuracy: 0.4339 - val_loss: 1.4259 - val_accuracy: 0.4880 Epoch 3/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.4134 - accuracy: 0.4926 - val_loss: 1.3415 - val_accuracy: 0.5247 Epoch 4/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.3098 - accuracy: 0.5341 - val_loss: 1.2024 - val_accuracy: 0.5746 Epoch 5/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2294 - accuracy: 0.5642 - val_loss: 1.1324 - val_accuracy: 0.5942 Epoch 6/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1704 - accuracy: 0.5865 - val_loss: 1.0809 - val_accuracy: 0.6171 Epoch 7/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1135 - accuracy: 0.6049 - val_loss: 1.0309 - val_accuracy: 0.6378 Epoch 8/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0752 - accuracy: 0.6215 - val_loss: 0.9815 - val_accuracy: 0.6627 Epoch 9/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0263 - accuracy: 0.6421 - val_loss: 0.9950 - val_accuracy: 0.6521 Epoch 10/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.9978 - accuracy: 0.6480 - val_loss: 0.9362 - val_accuracy: 0.6737 Epoch 11/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.9605 - accuracy: 0.6643 - val_loss: 0.9166 - val_accuracy: 0.6802 Epoch 12/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.9215 - accuracy: 0.6775 - val_loss: 0.8844 - val_accuracy: 0.6890 Epoch 13/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.9061 - accuracy: 0.6847 - val_loss: 0.8900 - val_accuracy: 0.6866 Epoch 14/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.8929 - accuracy: 0.6877 - val_loss: 0.8679 - val_accuracy: 0.6972 Epoch 15/50 50000/50000 [==============================] - 2s 47us/sample - loss: 0.8576 - accuracy: 0.7009 - val_loss: 0.8285 - val_accuracy: 0.7132 Epoch 16/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.8410 - accuracy: 0.7056 - val_loss: 0.8449 - val_accuracy: 0.7058 Epoch 17/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.8287 - accuracy: 0.7099 - val_loss: 0.8283 - val_accuracy: 0.7163 Epoch 18/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.8095 - accuracy: 0.7197 - val_loss: 0.8101 - val_accuracy: 0.7167 Epoch 19/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7936 - accuracy: 0.7225 - val_loss: 0.8095 - val_accuracy: 0.7158 Epoch 20/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7821 - accuracy: 0.7249 - val_loss: 0.7828 - val_accuracy: 0.7287 Epoch 21/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7570 - accuracy: 0.7343 - val_loss: 0.7676 - val_accuracy: 0.7361 Epoch 22/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7530 - accuracy: 0.7362 - val_loss: 0.7566 - val_accuracy: 0.7400 Epoch 23/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7359 - accuracy: 0.7427 - val_loss: 0.7447 - val_accuracy: 0.7400 Epoch 24/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7184 - accuracy: 0.7500 - val_loss: 0.7661 - val_accuracy: 0.7344 Epoch 25/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7123 - accuracy: 0.7514 - val_loss: 0.7665 - val_accuracy: 0.7379 Epoch 26/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.7049 - accuracy: 0.7527 - val_loss: 0.7560 - val_accuracy: 0.7373 Epoch 27/50 50000/50000 [==============================] - 2s 47us/sample - loss: 0.6921 - accuracy: 0.7577 - val_loss: 0.7215 - val_accuracy: 0.7510 Epoch 28/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6675 - accuracy: 0.7668 - val_loss: 0.7128 - val_accuracy: 0.7551 Epoch 29/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6602 - accuracy: 0.7672 - val_loss: 0.7287 - val_accuracy: 0.7513 Epoch 30/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6484 - accuracy: 0.7737 - val_loss: 0.7111 - val_accuracy: 0.7596 Epoch 31/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6415 - accuracy: 0.7747 - val_loss: 0.7024 - val_accuracy: 0.7582 Epoch 32/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.6203 - accuracy: 0.7815 - val_loss: 0.7069 - val_accuracy: 0.7572 Epoch 33/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.6131 - accuracy: 0.7841 - val_loss: 0.6937 - val_accuracy: 0.7595 Epoch 34/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.6031 - accuracy: 0.7886 - val_loss: 0.6833 - val_accuracy: 0.7687 Epoch 35/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5930 - accuracy: 0.7929 - val_loss: 0.6763 - val_accuracy: 0.7677 Epoch 36/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5828 - accuracy: 0.7959 - val_loss: 0.6868 - val_accuracy: 0.7634 Epoch 37/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.5815 - accuracy: 0.7944 - val_loss: 0.6783 - val_accuracy: 0.7720 Epoch 38/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5713 - accuracy: 0.7990 - val_loss: 0.6674 - val_accuracy: 0.7725 Epoch 39/50 50000/50000 [==============================] - 2s 47us/sample - loss: 0.5635 - accuracy: 0.8019 - val_loss: 0.6718 - val_accuracy: 0.7682 Epoch 40/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.5543 - accuracy: 0.8043 - val_loss: 0.6684 - val_accuracy: 0.7724 Epoch 41/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5426 - accuracy: 0.8080 - val_loss: 0.6688 - val_accuracy: 0.7764 Epoch 42/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5493 - accuracy: 0.8076 - val_loss: 0.6535 - val_accuracy: 0.7773 Epoch 43/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5249 - accuracy: 0.8149 - val_loss: 0.6529 - val_accuracy: 0.7814 Epoch 44/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5168 - accuracy: 0.8181 - val_loss: 0.6471 - val_accuracy: 0.7821 Epoch 45/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5143 - accuracy: 0.8182 - val_loss: 0.6467 - val_accuracy: 0.7791 Epoch 46/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5068 - accuracy: 0.8218 - val_loss: 0.6602 - val_accuracy: 0.7782 Epoch 47/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.5008 - accuracy: 0.8235 - val_loss: 0.6537 - val_accuracy: 0.7796 Epoch 48/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4944 - accuracy: 0.8276 - val_loss: 0.6485 - val_accuracy: 0.7810 Epoch 49/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4901 - accuracy: 0.8279 - val_loss: 0.6500 - val_accuracy: 0.7795 Epoch 50/50 50000/50000 [==============================] - 2s 46us/sample - loss: 0.4809 - accuracy: 0.8306 - val_loss: 0.6387 - val_accuracy: 0.7846
results = pd.DataFrame(history.history)
results[['accuracy', 'val_accuracy']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe1a6144e0>
<Figure size 432x288 with 1 Axes>
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe1a5af5f8>
<Figure size 432x288 with 1 Axes>
results.tail(1)
      
loss accuracy val_loss val_accuracy
49 0.480941 0.83056 0.63874 0.7846
Train Test
Base Accuracy 0.911 0.726
Base Loss 0.274 1.007
Dropout Accuracy 0.83 0.785
Dropout Loss 0.481 0.639

学習用データセットの正解率とテスト用データセットの正解率の乖離が小さくなり、過学習が抑えられたことが確認できました。

正則化 (Regularization)

パラメータの値の大きさに対してモデルの複雑さが増すことに対するペナルティを設け、過学習を抑える方法があります。これを正則化と呼びます。機械学習の基礎 : 回帰の章で詳しく説明しているので、そちらも合わせてご覧ください。

正則化には大きく 2 種類あり、

  • L1 正則化 (Lasso):パラメータの絶対値の総和を用い、極端なデータの重みを 0 にする
  • L2 正則化 (Ridge):パラメータの二乗の総和を用い、極端なデータの重みを 0 に近づける

ただし注意点として、過学習をしているからといってむやみに正則化をしすぎると逆に学習不足(アンダーフィッティング)となって精度が落ちることもあるので注意が必要です。学習不足の原因は様々であり、

  • モデルが十分複雑でない
  • 正則化が強すぎる
  • 単に学習時間が短すぎる

といった理由が挙げられます。学習不足は、学習用データの中の関連したパターンを学習しきれていないということを意味します。

TensorFlow では各層の kernel_regularizer 引数で指定できます。

from tensorflow.keras import regularizers
      
# シードの固定
reset_seed(0)

# モデルのインスタンス化
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l2(1e-2), input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l2(1e-2)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same', kernel_regularizer=regularizers.l2(1e-2)),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=1e-3)

# モデルのコンパイル
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])
      
# 学習の詳細設定
batch_size = 1024
epochs = 50

# 学習の実行
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test))
      
Train on 50000 samples, validate on 10000 samples Epoch 1/50 50000/50000 [==============================] - 3s 60us/sample - loss: 2.6566 - accuracy: 0.2608 - val_loss: 2.0875 - val_accuracy: 0.3737 Epoch 2/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.9356 - accuracy: 0.4017 - val_loss: 1.8152 - val_accuracy: 0.4303 Epoch 3/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.7839 - accuracy: 0.4389 - val_loss: 1.7835 - val_accuracy: 0.4438 Epoch 4/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.6825 - accuracy: 0.4708 - val_loss: 1.6470 - val_accuracy: 0.4797 Epoch 5/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.6240 - accuracy: 0.4900 - val_loss: 1.5997 - val_accuracy: 0.4959 Epoch 6/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.5665 - accuracy: 0.5104 - val_loss: 1.5426 - val_accuracy: 0.5159 Epoch 7/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.5120 - accuracy: 0.5286 - val_loss: 1.4953 - val_accuracy: 0.5375 Epoch 8/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.4798 - accuracy: 0.5401 - val_loss: 1.4567 - val_accuracy: 0.5472 Epoch 9/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.4455 - accuracy: 0.5535 - val_loss: 1.4271 - val_accuracy: 0.5541 Epoch 10/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.4179 - accuracy: 0.5642 - val_loss: 1.4005 - val_accuracy: 0.5644 Epoch 11/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.3829 - accuracy: 0.5748 - val_loss: 1.3758 - val_accuracy: 0.5763 Epoch 12/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.3500 - accuracy: 0.5912 - val_loss: 1.3603 - val_accuracy: 0.5800 Epoch 13/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.3324 - accuracy: 0.5961 - val_loss: 1.3630 - val_accuracy: 0.5793 Epoch 14/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.3140 - accuracy: 0.6011 - val_loss: 1.3282 - val_accuracy: 0.5913 Epoch 15/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2985 - accuracy: 0.6091 - val_loss: 1.3029 - val_accuracy: 0.6008 Epoch 16/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2884 - accuracy: 0.6124 - val_loss: 1.2946 - val_accuracy: 0.6074 Epoch 17/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.2627 - accuracy: 0.6219 - val_loss: 1.2679 - val_accuracy: 0.6188 Epoch 18/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2469 - accuracy: 0.6295 - val_loss: 1.2601 - val_accuracy: 0.6205 Epoch 19/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2316 - accuracy: 0.6320 - val_loss: 1.2612 - val_accuracy: 0.6206 Epoch 20/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.2212 - accuracy: 0.6357 - val_loss: 1.2442 - val_accuracy: 0.6307 Epoch 21/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.2003 - accuracy: 0.6469 - val_loss: 1.2219 - val_accuracy: 0.6401 Epoch 22/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1885 - accuracy: 0.6508 - val_loss: 1.1927 - val_accuracy: 0.6462 Epoch 23/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1742 - accuracy: 0.6557 - val_loss: 1.1886 - val_accuracy: 0.6512 Epoch 24/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1554 - accuracy: 0.6634 - val_loss: 1.2149 - val_accuracy: 0.6421 Epoch 25/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.1595 - accuracy: 0.6614 - val_loss: 1.1862 - val_accuracy: 0.6488 Epoch 26/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.1370 - accuracy: 0.6699 - val_loss: 1.1616 - val_accuracy: 0.6592 Epoch 27/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1296 - accuracy: 0.6709 - val_loss: 1.1586 - val_accuracy: 0.6663 Epoch 28/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.1137 - accuracy: 0.6768 - val_loss: 1.1496 - val_accuracy: 0.6667 Epoch 29/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.1072 - accuracy: 0.6806 - val_loss: 1.1496 - val_accuracy: 0.6617 Epoch 30/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.1057 - accuracy: 0.6785 - val_loss: 1.1507 - val_accuracy: 0.6658 Epoch 31/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0885 - accuracy: 0.6852 - val_loss: 1.1505 - val_accuracy: 0.6646 Epoch 32/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0856 - accuracy: 0.6878 - val_loss: 1.1359 - val_accuracy: 0.6677 Epoch 33/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0869 - accuracy: 0.6858 - val_loss: 1.1172 - val_accuracy: 0.6804 Epoch 34/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0695 - accuracy: 0.6932 - val_loss: 1.1445 - val_accuracy: 0.6646 Epoch 35/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0577 - accuracy: 0.6990 - val_loss: 1.1316 - val_accuracy: 0.6724 Epoch 36/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0585 - accuracy: 0.6960 - val_loss: 1.1095 - val_accuracy: 0.6786 Epoch 37/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0390 - accuracy: 0.7054 - val_loss: 1.0975 - val_accuracy: 0.6832 Epoch 38/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0344 - accuracy: 0.7058 - val_loss: 1.1033 - val_accuracy: 0.6799 Epoch 39/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0312 - accuracy: 0.7083 - val_loss: 1.1006 - val_accuracy: 0.6800 Epoch 40/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0196 - accuracy: 0.7116 - val_loss: 1.0940 - val_accuracy: 0.6822 Epoch 41/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0137 - accuracy: 0.7125 - val_loss: 1.0889 - val_accuracy: 0.6871 Epoch 42/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0052 - accuracy: 0.7160 - val_loss: 1.1088 - val_accuracy: 0.6733 Epoch 43/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.0199 - accuracy: 0.7113 - val_loss: 1.0753 - val_accuracy: 0.6934 Epoch 44/50 50000/50000 [==============================] - 2s 46us/sample - loss: 1.0143 - accuracy: 0.7147 - val_loss: 1.0851 - val_accuracy: 0.6857 Epoch 45/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9905 - accuracy: 0.7221 - val_loss: 1.0613 - val_accuracy: 0.6985 Epoch 46/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9817 - accuracy: 0.7258 - val_loss: 1.0772 - val_accuracy: 0.6904 Epoch 47/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9814 - accuracy: 0.7264 - val_loss: 1.0602 - val_accuracy: 0.6980 Epoch 48/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9772 - accuracy: 0.7297 - val_loss: 1.0823 - val_accuracy: 0.6903 Epoch 49/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9904 - accuracy: 0.7236 - val_loss: 1.0901 - val_accuracy: 0.6911 Epoch 50/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9649 - accuracy: 0.7318 - val_loss: 1.0519 - val_accuracy: 0.6985
results = pd.DataFrame(history.history)
results[['accuracy', 'val_accuracy']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe1a451198>
<Figure size 432x288 with 1 Axes>
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe1a526048>
<Figure size 432x288 with 1 Axes>
results.tail(1)
      
loss accuracy val_loss val_accuracy
49 0.964864 0.7318 1.051897 0.6985
Train Test
Base Accuracy 0.911 0.726
Base Loss 0.274 1.007
Dropout Accuracy 0.83 0.785
Dropout Loss 0.481 0.639
Regularization Accuracy 0.732 0.699
Regularization Loss 0.965 1.052

学習データの正解率と検証データの正解率の乖離が小さくなり、過学習が抑えられたことが確認できました。

早期終了

35_03

早期終了は Early Stopping と英語で呼ばれる事が多いため、本資料では Early Stopping という言葉を用います。
Early Stopping は学習が進まなくなった場合、途中であっても学習を打ち切ることができる機能です。

モデルの学習時に、callbackstf.keras.callbacks.EarlyStopping を追加します。

# シードの固定
reset_seed(0)

# モデルのインスタンス化
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])
    
# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=1e-3)

# モデルのコンパイル
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])
      
# 学習の詳細設定
batch_size = 1024
epochs = 50

# Early Stopping
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

# 学習の実行
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test),
                    callbacks=[callback])
      
Train on 50000 samples, validate on 10000 samples Epoch 1/50 50000/50000 [==============================] - 3s 56us/sample - loss: 1.9347 - accuracy: 0.2981 - val_loss: 1.6739 - val_accuracy: 0.4000 Epoch 2/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.5554 - accuracy: 0.4419 - val_loss: 1.4808 - val_accuracy: 0.4668 Epoch 3/50 50000/50000 [==============================] - 2s 45us/sample - loss: 1.3964 - accuracy: 0.5022 - val_loss: 1.3633 - val_accuracy: 0.5202 Epoch 4/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.2971 - accuracy: 0.5403 - val_loss: 1.2470 - val_accuracy: 0.5581 Epoch 5/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.2049 - accuracy: 0.5770 - val_loss: 1.1845 - val_accuracy: 0.5826 Epoch 6/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.1334 - accuracy: 0.6069 - val_loss: 1.1040 - val_accuracy: 0.6055 Epoch 7/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.0725 - accuracy: 0.6270 - val_loss: 1.0723 - val_accuracy: 0.6288 Epoch 8/50 50000/50000 [==============================] - 2s 44us/sample - loss: 1.0384 - accuracy: 0.6380 - val_loss: 1.0461 - val_accuracy: 0.6307 Epoch 9/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9789 - accuracy: 0.6610 - val_loss: 1.0689 - val_accuracy: 0.6355 Epoch 10/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9620 - accuracy: 0.6683 - val_loss: 0.9815 - val_accuracy: 0.6595 Epoch 11/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.9192 - accuracy: 0.6818 - val_loss: 0.9582 - val_accuracy: 0.6679 Epoch 12/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.8802 - accuracy: 0.6967 - val_loss: 0.9628 - val_accuracy: 0.6700 Epoch 13/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.8766 - accuracy: 0.6993 - val_loss: 0.9502 - val_accuracy: 0.6730 Epoch 14/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.8444 - accuracy: 0.7113 - val_loss: 0.9266 - val_accuracy: 0.6787 Epoch 15/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.8108 - accuracy: 0.7226 - val_loss: 0.8925 - val_accuracy: 0.6938 Epoch 16/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.7914 - accuracy: 0.7279 - val_loss: 0.9067 - val_accuracy: 0.6907 Epoch 17/50 50000/50000 [==============================] - 2s 45us/sample - loss: 0.7703 - accuracy: 0.7357 - val_loss: 0.8936 - val_accuracy: 0.6952 Epoch 18/50 50000/50000 [==============================] - 2s 44us/sample - loss: 0.7651 - accuracy: 0.7381 - val_loss: 0.9017 - val_accuracy: 0.6877
results = pd.DataFrame(history.history)
results[['accuracy', 'val_accuracy']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe1a0b0f98>
<Figure size 432x288 with 1 Axes>
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe127cd7b8>
<Figure size 432x288 with 1 Axes>
results.tail(1)
      
loss accuracy val_loss val_accuracy
17 0.765085 0.73806 0.901713 0.6877
Train Test
Base Accuracy 0.911 0.726
Base Loss 0.274 1.007
Dropout Accuracy 0.83 0.785
Dropout Loss 0.481 0.639
Regularization Accuracy 0.732 0.699
Regularization Loss 0.965 1.052
Early Stopping Accuracy 0.738 0.688
Early Stopping Loss 0.765 0.902

エポック数は最大 50 回までと設定していましたが、val_loss が下がらなくなるタイミング( 18 エポック目)で学習が打ち切りされました。

注意点として、ある一定期間に値の向上が見られない場合でも、しばらくするとまた向上することがあるということです。効率的に学習するために、大きめのエポック数を準備して学習を実行することもありますが、ケースバイケースという点は抑えておきましょう。

バッチノーマリゼーション

バッチノーマリゼーションは、ミニバッチごとに平均 x¯\bar x と 標準偏差 σ\sigma を求め、

xs=xx¯σx^=αxs+β\begin{aligned} x_s &= \frac{x - \bar x}{\sigma} \\\\ \hat x &= \alpha x_s + \beta \end{aligned}

のように x^\hat x へと各変数ごとに変換を行います。ここで、α\alphaβ\beta はハイパーパラメータであり、単純な正規化のように平均 0、標準偏差 1 とするのではなく、平均 β\beta、標準偏差 α\alpha となるように変換を行います。必ずしも平均 0、標準偏差 1 が良いとは限らないためです。

実装としては、各バッチ毎に平均と標準偏差を定めて標準化を行うといった非常に簡単な手法なのですが、こちらを層に加えることで各変数感のスケールによる差を吸収できます。

それでは、バッチノーマリゼーションがある場合で試してみましょう。

# シードの固定
reset_seed(0)

#モデルのインスタンス化
model = models.Sequential([
    layers.Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), padding='same', activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax')
])

# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=1e-3)

# モデルのコンパイル
model.compile(loss='sparse_categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])
      
# 学習の詳細設定
batch_size = 1024
epochs = 50

# 学習の実行
history = model.fit(x_train, t_train, 
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, t_test))
      
Train on 50000 samples, validate on 10000 samples Epoch 1/50 50000/50000 [==============================] - 4s 81us/sample - loss: 1.9106 - accuracy: 0.3634 - val_loss: 3.3544 - val_accuracy: 0.1000 Epoch 2/50 50000/50000 [==============================] - 3s 55us/sample - loss: 1.3606 - accuracy: 0.5097 - val_loss: 5.3415 - val_accuracy: 0.1000 Epoch 3/50 50000/50000 [==============================] - 3s 54us/sample - loss: 1.1802 - accuracy: 0.5772 - val_loss: 5.6006 - val_accuracy: 0.1005 Epoch 4/50 50000/50000 [==============================] - 3s 54us/sample - loss: 1.0651 - accuracy: 0.6187 - val_loss: 4.8560 - val_accuracy: 0.1692 Epoch 5/50 50000/50000 [==============================] - 3s 54us/sample - loss: 0.9774 - accuracy: 0.6511 - val_loss: 4.9619 - val_accuracy: 0.2108 Epoch 6/50 50000/50000 [==============================] - 3s 54us/sample - loss: 0.9032 - accuracy: 0.6781 - val_loss: 4.6061 - val_accuracy: 0.1928 Epoch 7/50 50000/50000 [==============================] - 3s 58us/sample - loss: 0.8382 - accuracy: 0.7012 - val_loss: 3.4456 - val_accuracy: 0.2426 Epoch 8/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.7944 - accuracy: 0.7168 - val_loss: 3.0452 - val_accuracy: 0.2907 Epoch 9/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.7536 - accuracy: 0.7331 - val_loss: 2.4256 - val_accuracy: 0.3600 Epoch 10/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.7133 - accuracy: 0.7433 - val_loss: 1.5773 - val_accuracy: 0.5038 Epoch 11/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.6791 - accuracy: 0.7601 - val_loss: 1.1261 - val_accuracy: 0.6129 Epoch 12/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.6498 - accuracy: 0.7709 - val_loss: 0.9490 - val_accuracy: 0.6643 Epoch 13/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.6155 - accuracy: 0.7798 - val_loss: 0.8856 - val_accuracy: 0.6912 Epoch 14/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.5788 - accuracy: 0.7942 - val_loss: 0.7944 - val_accuracy: 0.7310 Epoch 15/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.5523 - accuracy: 0.8038 - val_loss: 0.8068 - val_accuracy: 0.7265 Epoch 16/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.5344 - accuracy: 0.8090 - val_loss: 0.7180 - val_accuracy: 0.7534 Epoch 17/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.5144 - accuracy: 0.8175 - val_loss: 0.7705 - val_accuracy: 0.7382 Epoch 18/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.4852 - accuracy: 0.8269 - val_loss: 0.7543 - val_accuracy: 0.7535 Epoch 19/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.4659 - accuracy: 0.8340 - val_loss: 0.7666 - val_accuracy: 0.7452 Epoch 20/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.4414 - accuracy: 0.8415 - val_loss: 0.7858 - val_accuracy: 0.7496 Epoch 21/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.4257 - accuracy: 0.8481 - val_loss: 0.7907 - val_accuracy: 0.7483 Epoch 22/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.4131 - accuracy: 0.8527 - val_loss: 0.7541 - val_accuracy: 0.7535 Epoch 23/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.3895 - accuracy: 0.8599 - val_loss: 0.7313 - val_accuracy: 0.7626 Epoch 24/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.3773 - accuracy: 0.8644 - val_loss: 0.7209 - val_accuracy: 0.7629 Epoch 25/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.3569 - accuracy: 0.8714 - val_loss: 0.7986 - val_accuracy: 0.7464 Epoch 26/50 50000/50000 [==============================] - 3s 54us/sample - loss: 0.3522 - accuracy: 0.8730 - val_loss: 0.8473 - val_accuracy: 0.7299 Epoch 27/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.3363 - accuracy: 0.8803 - val_loss: 0.7806 - val_accuracy: 0.7540 Epoch 28/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.3174 - accuracy: 0.8858 - val_loss: 0.7401 - val_accuracy: 0.7717 Epoch 29/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.3085 - accuracy: 0.8864 - val_loss: 0.8967 - val_accuracy: 0.7261 Epoch 30/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.3042 - accuracy: 0.8897 - val_loss: 0.7645 - val_accuracy: 0.7615 Epoch 31/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2862 - accuracy: 0.8958 - val_loss: 0.8164 - val_accuracy: 0.7610 Epoch 32/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2786 - accuracy: 0.8990 - val_loss: 0.8001 - val_accuracy: 0.7618 Epoch 33/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2692 - accuracy: 0.9033 - val_loss: 0.7933 - val_accuracy: 0.7652 Epoch 34/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.2585 - accuracy: 0.9064 - val_loss: 0.8336 - val_accuracy: 0.7513 Epoch 35/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2552 - accuracy: 0.9076 - val_loss: 0.7700 - val_accuracy: 0.7778 Epoch 36/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2454 - accuracy: 0.9114 - val_loss: 0.7553 - val_accuracy: 0.7797 Epoch 37/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2413 - accuracy: 0.9129 - val_loss: 0.8006 - val_accuracy: 0.7759 Epoch 38/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2379 - accuracy: 0.9125 - val_loss: 0.7696 - val_accuracy: 0.7789 Epoch 39/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.2289 - accuracy: 0.9178 - val_loss: 0.7987 - val_accuracy: 0.7689 Epoch 40/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2207 - accuracy: 0.9210 - val_loss: 0.7610 - val_accuracy: 0.7829 Epoch 41/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2120 - accuracy: 0.9239 - val_loss: 0.8622 - val_accuracy: 0.7623 Epoch 42/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.2121 - accuracy: 0.9230 - val_loss: 0.8155 - val_accuracy: 0.7716 Epoch 43/50 50000/50000 [==============================] - 3s 54us/sample - loss: 0.1995 - accuracy: 0.9275 - val_loss: 0.8227 - val_accuracy: 0.7730 Epoch 44/50 50000/50000 [==============================] - 3s 54us/sample - loss: 0.2005 - accuracy: 0.9276 - val_loss: 0.8268 - val_accuracy: 0.7729 Epoch 45/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.1964 - accuracy: 0.9287 - val_loss: 0.8623 - val_accuracy: 0.7710 Epoch 46/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.1936 - accuracy: 0.9301 - val_loss: 0.8813 - val_accuracy: 0.7619 Epoch 47/50 50000/50000 [==============================] - 3s 56us/sample - loss: 0.1821 - accuracy: 0.9339 - val_loss: 0.7952 - val_accuracy: 0.7820 Epoch 48/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.1829 - accuracy: 0.9346 - val_loss: 0.8254 - val_accuracy: 0.7736 Epoch 49/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.1846 - accuracy: 0.9337 - val_loss: 0.8507 - val_accuracy: 0.7780 Epoch 50/50 50000/50000 [==============================] - 3s 55us/sample - loss: 0.1780 - accuracy: 0.9355 - val_loss: 0.8375 - val_accuracy: 0.7784
results = pd.DataFrame(history.history)
results[['accuracy', 'val_accuracy']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe30421a90>
<Figure size 432x288 with 1 Axes>
results[['loss', 'val_loss']].plot()
      
<matplotlib.axes._subplots.AxesSubplot at 0x7fbe13f36e10>
<Figure size 432x288 with 1 Axes>
results.tail(1)
      
loss accuracy val_loss val_accuracy
49 0.178049 0.93546 0.837536 0.7784
Train Test
Base Accuracy 0.911 0.726
Base Loss 0.274 1.007
Dropout Accuracy 0.83 0.785
Dropout Loss 0.481 0.639
Regularization Accuracy 0.732 0.699
Regularization Loss 0.965 1.052
Early Stopping Accuracy 0.738 0.688
Early Stopping Loss 0.765 0.902
Batch Normalization Accuracy 0.935 0.778
Batch Normalization Loss 0.178 0.838

学習用データセット、テスト用データセット共に正解率の向上が確認できました。ベースモデルほどの乖離もなく、やはりバッチノーマリゼーションは効果的な手法であることが実験からわかりました。

活性化関数

これまで活性化関数として、ReLU 関数を多く使ってきましたが、ニューラルネットワークに活性化関数(非線型変換)を用いる理由は何でしょうか。それは、モデルの表現力を増すためです。

では、なぜ活性化関数を用いるとニューラルネットワークの表現力が増すのでしょうか。線形変換をいくら繋げても、それらは一つの線形変換で表現し直すことができることになり、層を深いディープラーニングの恩恵を預かることができなくなってしまうからです。非線形な関数を表現するには、非線形な関数を活性化関数に入れようというのは直感的に理解できそうです。

また、ReLU 関数の他でよく使われる活性化関数としては、

  • sigmoid 関数
  • tanh 関数

が代表されるので、まずはこちらを押さえておきましょう。

他にも ReLU 関数がよく使われているのは以下のような理由があります。

  • max(0,x)\max(0, x) は単純ゆえに計算コストが低い
  • x>0x > 0 の部分では微分値が常に 1 であるため勾配消失の心配が少なくなる

より高度な活性化関数

上記で説明した単純な関数よりも高度な活性化関数を用いたい場合は、tensorflow.keras.layers にありますので、こちらの公式ドキュメントをご覧ください。

有名な活性化関数を確認しましょう。

def sigmoid(x):
    return 1 / (1+np.exp(-x))

def relu(x):
    return np.maximum(0, x)

def tanh(x):
    return np.tanh(x)

def leaky_relu(x):
    return np.maximum(x, 0.01*x)

def prelu(x, a):
    return np.maximum(x, a*x)

fig = plt.figure(figsize=(10, 6))
x = np.linspace(-10, 10, 1000)

ax = fig.add_subplot(111)
ax.plot(x, sigmoid(x), label='sigmoid')
ax.plot(x, relu(x), label='ReLU')
ax.plot(x, tanh(x), label='tanh')
ax.plot(x, leaky_relu(x), label='leaky_relu')
ax.plot(x, prelu(x, 0.08), label='prelu')

plt.legend()
plt.xlim(-5, 5)
plt.ylim(-1.1, 2)
plt.grid(color='white', linestyle='-')
plt.show();
      
<Figure size 720x432 with 1 Axes>
shareアイコン