ファインチューニング

本章では、ディープラーニングで多く用いられるファインチューニング (fine tuning) について学んでいきます。ファインチューニングとは、異なるデータセットで学習済みのモデルに関して一部または全部を再利用して、新しいモデルを構築する手法です。モデルの構造とパラメータを活用し、特徴抽出器としての機能を果たします。手持ちのデータセットのサンプル数が少ないがために精度があまり出ない場合でも、ファインチューニングを使用すれば、性能が向上する場合があります。

34_3

類似した用語として転移学習 (transfer learning) がありますが、学習済みモデルをそのまま使用するか、一部を使用するかで異なります。

34_2

本章の流れは、

  1. データを準備して理解する
  2. 自作した分類モデルで学習と評価
  3. 事前学習済みモデルの重みを読み込んで学習と評価

上記のように、一度自作のモデルを組んで結果を確認した後に、ファインチューニングをすればどの程度精度が向上するかを見ていきます。

本章の構成

  • 画像分類
  • ファインチューニング

画像分類

前章で扱った手書き文字である MNIST の分類は簡単なネットワークの定義でもある程度の正解率が得られますが、少し難易度をあげた問題設定で試してみましょう。CIFAR10 と呼ばれる以下のような 10 クラスの分類を行います。CIFAR10 も MNIST と同様に、tf.keras.datasets にデータセットが用意されています。

34_1

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2

import tensorflow as tf
      
# GPU が使用可能であることを確認
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
      
[name: "/device:CPU:0" device_type: "CPU" memory_limit: 268435456 locality { } incarnation: 9826463659935877047 , name: "/device:XLA_CPU:0" device_type: "XLA_CPU" memory_limit: 17179869184 locality { } incarnation: 7509225710905663396 physical_device_desc: "device: XLA_CPU device" , name: "/device:XLA_GPU:0" device_type: "XLA_GPU" memory_limit: 17179869184 locality { } incarnation: 13976982802428851630 physical_device_desc: "device: XLA_GPU device" , name: "/device:GPU:0" device_type: "GPU" memory_limit: 15956161332 locality { bus_id: 1 links { } } incarnation: 4970730678248293988 physical_device_desc: "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0" ]

データセットの準備

MNIST と同じようにデータセットを準備しましょう。

# データセットの読み込み
train, test = tf.keras.datasets.cifar10.load_data()
      

こちらのデータセットも MNIST と同じデータセット構成をしています。これまでの流れを参考に TensorFlow で扱えるデータ形式まで変換していきましょう。

# 画像の情報
train[0].shape, test[0].shape
      
((50000, 32, 32, 3), (10000, 32, 32, 3))
# ラベルの情報
train[1].shape, test[1].shape
      
((50000, 1), (10000, 1))
# 学習用データセットとテスト用データセットに対して正規化
x_train = train[0] / 255
x_test = test[0] / 255
      
# 目標値の切り分け
t_train = train[1]
t_test = test[1]
      
# 32bit にキャスト
x_train, x_test = x_train.astype('float32'), x_test.astype('float32')
t_train, t_test = t_train.astype('int32'), t_test.astype('int32')
      

CNN モデルの定義

前章の MNIST で 97% を出したモデルでもう一度試してみましょう。どれくらい精度を出すことができるでしょうか。CNN モデルの概要を再度掲載しておきます。

34_4

import os, random

def reset_seed(seed=0):
    os.environ['PYTHONHASHSEED'] = '0'
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
      

画像サイズが変わっていることに注意してください。MNIST では (28, 28, 1) でしたが、CIFAR10 では (32, 32, 3) になります。

from tensorflow.keras import models,layers

# シードの固定
reset_seed(0)

# モデルの構築
model = models.Sequential([
    # 特徴量抽出
    layers.Conv2D(filters=3, kernel_size=(3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPool2D(pool_size=(2, 2)),
    # ベクトル化
    layers.Flatten(),
    # 識別
    layers.Dense(100, activation='relu'),
    layers.Dense(10, activation='softmax')
])
      

モデルの定義が完了しました。summary() メソッドでパラメータを確認します。

model.summary()
      
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 30, 30, 3) 84 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 15, 15, 3) 0 _________________________________________________________________ flatten (Flatten) (None, 675) 0 _________________________________________________________________ dense (Dense) (None, 100) 67600 _________________________________________________________________ dense_1 (Dense) (None, 10) 1010 ================================================================= Total params: 68,694 Trainable params: 68,694 Non-trainable params: 0 _________________________________________________________________

目的関数と最適化手法の選択

今回は最適化の手法に Adam を、目的関数は分類の問題設定のため sparse categorical crossentropy を使用します。

# optimizerの設定
optimizer = tf.keras.optimizers.Adam(lr=0.01)

# モデルのコンパイル
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
      

モデルの学習

バッチサイズ、エポック数を定義して、モデルの学習を実行します。

# モデルの学習
batch_size = 4096
epochs = 30

# 学習の実行
history = model.fit(x_train, t_train,
                batch_size=batch_size,
                epochs=epochs, 
                validation_data=(x_test, t_test))
      
Epoch 1/30 13/13 [==============================] - 4s 295ms/step - loss: 2.1355 - accuracy: 0.2159 - val_loss: 1.9075 - val_accuracy: 0.3281 Epoch 2/30 13/13 [==============================] - 4s 295ms/step - loss: 1.8525 - accuracy: 0.3518 - val_loss: 1.7404 - val_accuracy: 0.3903 Epoch 3/30 13/13 [==============================] - 4s 304ms/step - loss: 1.6645 - accuracy: 0.4066 - val_loss: 1.6260 - val_accuracy: 0.4183 Epoch 4/30 13/13 [==============================] - 4s 315ms/step - loss: 1.5667 - accuracy: 0.4392 - val_loss: 1.5662 - val_accuracy: 0.4408 Epoch 5/30 13/13 [==============================] - 4s 340ms/step - loss: 1.4920 - accuracy: 0.4677 - val_loss: 1.5024 - val_accuracy: 0.4590 Epoch 6/30 13/13 [==============================] - 4s 323ms/step - loss: 1.4366 - accuracy: 0.4878 - val_loss: 1.4795 - val_accuracy: 0.4708 Epoch 7/30 13/13 [==============================] - 4s 316ms/step - loss: 1.4183 - accuracy: 0.4967 - val_loss: 1.4590 - val_accuracy: 0.4815 Epoch 8/30 13/13 [==============================] - 4s 308ms/step - loss: 1.3659 - accuracy: 0.5159 - val_loss: 1.4263 - val_accuracy: 0.4961 Epoch 9/30 13/13 [==============================] - 4s 330ms/step - loss: 1.3280 - accuracy: 0.5305 - val_loss: 1.4218 - val_accuracy: 0.4984 Epoch 10/30 13/13 [==============================] - 5s 367ms/step - loss: 1.3034 - accuracy: 0.5395 - val_loss: 1.4382 - val_accuracy: 0.4936 Epoch 11/30 13/13 [==============================] - 5s 372ms/step - loss: 1.2812 - accuracy: 0.5487 - val_loss: 1.4176 - val_accuracy: 0.4992 Epoch 12/30 13/13 [==============================] - 4s 321ms/step - loss: 1.2476 - accuracy: 0.5585 - val_loss: 1.3834 - val_accuracy: 0.5136 Epoch 13/30 13/13 [==============================] - 4s 314ms/step - loss: 1.2189 - accuracy: 0.5697 - val_loss: 1.3907 - val_accuracy: 0.5157 Epoch 14/30 13/13 [==============================] - 4s 328ms/step - loss: 1.2012 - accuracy: 0.5757 - val_loss: 1.4028 - val_accuracy: 0.5109 Epoch 15/30 13/13 [==============================] - 4s 342ms/step - loss: 1.1776 - accuracy: 0.5873 - val_loss: 1.3729 - val_accuracy: 0.5220 Epoch 16/30 13/13 [==============================] - 6s 430ms/step - loss: 1.1451 - accuracy: 0.5985 - val_loss: 1.4114 - val_accuracy: 0.5139 Epoch 17/30 13/13 [==============================] - 6s 487ms/step - loss: 1.1438 - accuracy: 0.5997 - val_loss: 1.3796 - val_accuracy: 0.5170 Epoch 18/30 13/13 [==============================] - 5s 415ms/step - loss: 1.1109 - accuracy: 0.6101 - val_loss: 1.3757 - val_accuracy: 0.5218 Epoch 19/30 13/13 [==============================] - 4s 345ms/step - loss: 1.0901 - accuracy: 0.6172 - val_loss: 1.3667 - val_accuracy: 0.5231 Epoch 20/30 13/13 [==============================] - 4s 340ms/step - loss: 1.0745 - accuracy: 0.6235 - val_loss: 1.4226 - val_accuracy: 0.5191 Epoch 21/30 13/13 [==============================] - 4s 331ms/step - loss: 1.0759 - accuracy: 0.6207 - val_loss: 1.3936 - val_accuracy: 0.5222 Epoch 22/30 13/13 [==============================] - 5s 377ms/step - loss: 1.0521 - accuracy: 0.6311 - val_loss: 1.3942 - val_accuracy: 0.5258 Epoch 23/30 13/13 [==============================] - 5s 353ms/step - loss: 1.0237 - accuracy: 0.6433 - val_loss: 1.4094 - val_accuracy: 0.5190 Epoch 24/30 13/13 [==============================] - 5s 367ms/step - loss: 1.0276 - accuracy: 0.6403 - val_loss: 1.4246 - val_accuracy: 0.5243 Epoch 25/30 13/13 [==============================] - 5s 347ms/step - loss: 1.0133 - accuracy: 0.6449 - val_loss: 1.4068 - val_accuracy: 0.5240 Epoch 26/30 13/13 [==============================] - 5s 368ms/step - loss: 0.9911 - accuracy: 0.6530 - val_loss: 1.4217 - val_accuracy: 0.5253 Epoch 27/30 13/13 [==============================] - 4s 342ms/step - loss: 0.9721 - accuracy: 0.6608 - val_loss: 1.4358 - val_accuracy: 0.5170 Epoch 28/30 13/13 [==============================] - 4s 339ms/step - loss: 0.9627 - accuracy: 0.6654 - val_loss: 1.4374 - val_accuracy: 0.5216 Epoch 29/30 13/13 [==============================] - 4s 330ms/step - loss: 0.9525 - accuracy: 0.6707 - val_loss: 1.4802 - val_accuracy: 0.5144 Epoch 30/30 13/13 [==============================] - 4s 326ms/step - loss: 0.9445 - accuracy: 0.6694 - val_loss: 1.4464 - val_accuracy: 0.5269

10 クラスの分類であるため、ベースラインとなる正解率が 10% (=1/10) とすると、正解率を約 48 % 程度までは上昇させることができていますが、それでもまだ正解率は高いとは言えない結果になっています。

ファインチューニング

冒頭で説明した通り、ファインチューニングを実装します。学習済みモデルは世界中で公開されており、それぞれのタスクに合わせて学習済みモデルを使用する必要があります。

34_3

今回は、世界最大の画像認識コンペティション ImageNet Large Scale Visual Recognition Competition (ILSVRC) で使用されている 1000 クラスの物体を分類するタスクの学習済みモデルを利用します。この ILSVRC の学習済みモデルはフレームワーク側で用意されていることが多く、簡単に使い始めることができます。まずは、この学習済みモデルで試してみると良いでしょう。また、ネットワークの構造は VGG16 というモデルを使用します。今でも幅広く活用されているモデルです。

学習済みモデルは tf.keras.applications 以下に置かれています。今回使用するモデル以外にも画像分類モデルであれば、

  • VGG16, 19
  • ResNet
  • DenseNet
  • Inception v3
  • MobileNet v2
  • NASNet

などがあります。

最初に、ImageNet で学習された重みを持った VGG16 モデルをインスタンス化します。インスタンス化するときに、include_top=False とすることで、全結合層レイヤーを除くことができます。全結合層のレイヤーは、識別するためのレイヤーであり、ImageNet では 1000 クラス分類ですが、今回は 10 クラス分類なので最終層に全結合層を自分の問題設定に合うように追加します。学習済みの重みを持ったモデルは、特徴抽出器として役立ちます。また、学習済みの重みを使うためには weights='imagenet' とします。

  • ファインチューニングする際の注意点
    • include_top=False とすると事前学習済みの全結合層を除く
    • weights='imagenet' とすると事前学習済みの重みを引き継ぐ

それでは実装していきます。

from tensorflow.keras.applications import resnet, VGG16

# 学習済みモデルをインスタンス化
base_model = VGG16(input_shape=(224, 224, 3),
                                      include_top=False, weights='imagenet')
      
base_model.summary()
      
Model: "vgg16" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 ================================================================= Total params: 14,714,688 Trainable params: 14,714,688 Non-trainable params: 0 _________________________________________________________________

モデルの概要を確認したとおり、最終層が MaxPooling で終わっています。こちらは include_top=False にしたためです。

それではデータの準備から進めていきたいのですが、今回画像サイズを大きくしますので 50,000 枚すべてのデータを使用すると膨大な量の演算が必要になってきます。そのため、今回は学習用データセットを 10,000 枚、テスト用データセットを 5,000 枚にして実装していきます。

今回のようにランダムに整数で値を取得したい場合に便利なのが、np.random.randint() 関数です。np.random モジュール内のその他の乱数生成方法についてはこちらを確認してください。引数には、

  • low : 最小値
  • high : 最大値
  • size : 配列のサイズ

があります。例えば今回の学習用データセットから 10,000 枚の画像を取り出したい場合には np.random.randint(0, 50000, 10000) と実行します。

# ランダムにデータを取得する
train_choice = np.random.randint(low=0, high=50000, size=10000)
test_choice = np.random.randint(low=0, high=10000, size=5000)
      
# データの準備
x_train = train[0][train_choice]
x_test = test[0][test_choice]
t_train = train[1][train_choice].astype('int32') 
t_test = test[1][test_choice].astype('int32') 
      

また、VGG16 では画像サイズ (224, 224, 3) の画像で事前学習されているため、データサイズを変更(リサイズ)する必要があります。それぞれの事前学習済みモデルを使用する際にも基本的には学習時と同サイズにする必要があります。

リサイズ方法にはいくつも種類がありますが、今回は OpenCV の resize() を使用します。引数に画像 (src) とリサイズ後のサイズ (dsize) を指定しましょう。

_train, _test = [], []

# 画像サイズを 224 × 224 にリサイズしてリストに格納
for img in x_train:
    _train.append(cv2.resize(src=img, dsize=(224, 224)))
for img in x_test:    
    _test.append(cv2.resize(src=img, dsize=(224, 224)))
      
# リストから ndarray に変換し、正規化
x_train = np.array(_train, dtype='float32') / 255.0
x_test = np.array(_test, dtype='float32') / 255.0
      
x_train.shape, x_test.shape
      
((10000, 224, 224, 3), (5000, 224, 224, 3))

全結合層を追加

テンソルからベクトルに直す操作として、Flatten を使いますが、学習すべきパラメータ数が膨大に増加してしまいます。その代わりによく使われるレイヤーとして、GlobalAveragePooling2D があります。頭文字を取って GAPと呼ぶことも覚えておきましょう。

34_5

もしも最終層の特徴マップのサイズが (7, 7, 512) であれば、Flatten を使用すると、7×7×512 のベクトルになります。GlobalAveragePooling2D を使えば、特徴マップのチャネルごとの平均値を出力してくれるので、512512 ベクトルになります。1/491/49 のパラメータ削減になります。

そして、10 クラス分類なのでノード数 10 の全結合層も最終層に追加します。

reset_seed(0)

# モデルの定義
finetuned_model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(512, activation='relu'),
    layers.Dense(10, activation='softmax')
])
      

モデルをコンパイルします。最適化手法、目的関数、監視する評価指標を設定しましょう。今回は最適化手法に SGD を使用してみます。

optimizer = tf.keras.optimizers.SGD(lr=0.01)

# モデルのコンパイル
finetuned_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
      
finetuned_model.summary()
      
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= vgg16 (Model) (None, 7, 7, 512) 14714688 _________________________________________________________________ global_average_pooling2d (Gl (None, 512) 0 _________________________________________________________________ dense_2 (Dense) (None, 512) 262656 _________________________________________________________________ dense_3 (Dense) (None, 10) 5130 ================================================================= Total params: 14,982,474 Trainable params: 14,982,474 Non-trainable params: 0 _________________________________________________________________

モデルの学習

学習させる前に、一度学習前のモデルでテスト用データセットに対して順伝播させてみます。学習後のモデルと比較するためです。順伝播する方法は、evaluate() で行うことができます。

loss, accuracy = finetuned_model.evaluate(x_test, t_test)
      
60/157 [==========>...................] - ETA: 7:48 - loss: 2.4370 - accuracy: 0.0901
print(f'loss : {loss}, acuracy : {accuracy}')
      

正解率は 9% 程度です。ここから 10 エポックでどの程度正解率が向上するのか見ていきましょう。

# モデルの学習
history = finetuned_model.fit(x_train, t_train, 
                    epochs=10, 
                    batch_size=32,
                    validation_data=(x_test, t_test))
      

予測精度の評価

学習結果を確認します。

results = pd.DataFrame(history.history)
results.tail(3)
      
# 損失を可視化
results[['loss', 'val_loss']].plot(title='loss')
plt.xlabel('epochs');
      
# 正解率を可視化
results[['accuracy', 'val_accuracy']].plot(title='accuracy')
plt.xlabel('epochs');
      

上記の結果の通り、わずか 10 エポック分の学習で正解率を 86% まで高めることができました。さらなる学習やハイパーパラメータの調整などでさらに正解率を高められることが期待できます。ぜひ、挑戦してみてください。

shareアイコン