代表的なモデル

本章では、転移学習でよく用いられる代表的な CNN モデルを紹介します。

  • VGGNet
  • GoogLeNet / Inception
  • MobileNet
  • ResNet

これらは、物体検出やセマンティックセグメンテーションに使用されるモデルのバックボーンにも利用されることも多いため、押さえておく必要があります。

Functional API

複数の入出力が発生するようなモデルの場合、これまで学んできた Sequential API では記述できません。

したがって、Functional API での記述方法をはじめに抑えておきましょう。

import tensorflow as tf
from tensorflow.keras import models, layers

import numpy as np
      

Sequential API によるシンプルなネットワーク

# Sequential API モデル構築
model = models.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784, )),
    layers.Dense(64, activation='relu')
])

model.summary()
      
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 64) 50240 _________________________________________________________________ dense_1 (Dense) (None, 64) 4160 ================================================================= Total params: 54,400 Trainable params: 54,400 Non-trainable params: 0 _________________________________________________________________

このネットワークを書き換えてみましょう。

Functional API による記述

# FunctionalAPI によるモデル構築
inputs = layers.Input(shape=(784, ))

x = layers.Dense(64, activation='relu')(inputs)
outputs = layers.Dense(64, activation='relu')(x)

model = models.Model(inputs=inputs, outputs=outputs)
      

Fucntional API では上記のように、モデルの入出力を明示的に記述します。

model.summary()
      
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 784)] 0 _________________________________________________________________ dense_2 (Dense) (None, 64) 50240 _________________________________________________________________ dense_3 (Dense) (None, 64) 4160 ================================================================= Total params: 54,400 Trainable params: 54,400 Non-trainable params: 0 _________________________________________________________________

VGGNet(2014)

VGGNet は 2014 年に ILSVRC で画像分類タスクで 2 位となったモデルです。1 位ではないのですが、非常にシンプルでわかりやすいアーキテクチャで高い精度をたたき出したため、よく用いられます。

特徴は 3 点あります。

  • 3×33\times3 フィルタのみを使用
  • 同一チャネルの複数の畳み込み層と Max Pooling を 1 セットとし、繰り返す
  • Max Pooling 後の出力チャネル数を 2 倍にする

3×33\times3 のフィルタが採用されているのは、それが上下左右中心の情報を受容できる一番小さなサイズであるためです。

また、3×33\times3 での畳み込みを繰り返すことで、より大きなフィルタサイズを持つフィルタでの畳み込みを近似します。例えば 7×77\times7 での畳み込みを 1 回行う場合と、3×33\times3 での畳み込みを 3 回繰り返すこと場合では、同じサイズ (H, W) の特徴マップを出力することができます。下記の条件で考えてみましょう。

  • 入力の特徴マップのチャネル数が 16
  • 出力の特徴マップのチャネル数が 32

7×77\times7 の畳み込みを 1 回行う場合

  • 画像のサイズの推移:IN (32, 32, 16) → OUT(26, 26, 32)
  • パラメータ数:7×7×16×32=25,0887\times7\times16\times32 = 25,088

3×33\times3 の畳み込みを 3 回行う場合

  • 画像のサイズの推移:IN(32, 32, 16)→ (30, 30, 32)→(28, 28, 32)→ OUT(26, 26, 32)
  • パラメータ数:3×3×16×32×3=13,8243\times3\times16\times32\times3 = 13,824

となり、バイアスの影響を考慮していませんが、同じ範囲を半分程度のパラメータ数で見られることになります。さらにパラメータ数が削減できただけでなく、精度の向上も見られたため、ここから CNN モデルは 3×33\times3 をフィルタサイズの中心として研究されることになりました。

37_2

def vgg(input_shape, n_classes):
    
    input = layers.Input(shape=input_shape)
    
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(input)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2, strides=2, padding='same')(x)
    
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(128, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2, strides=2, padding='same')(x)
    
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2, strides=2, padding='same')(x)
    
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2, strides=2, padding='same')(x)
    
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(2, strides=2, padding='same')(x)
    
    x = layers.Flatten()(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dense(4096, activation='relu')(x)
    
    output = layers.Dense(n_classes, activation='softmax')(x)
    
    model = models.Model(input, output)
    
    return model
      
model = vgg((224, 224, 3), 1000)
model.summary()
      
Model: "model_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_2 (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ conv2d (Conv2D) (None, 224, 224, 64) 1792 _________________________________________________________________ conv2d_1 (Conv2D) (None, 224, 224, 64) 36928 _________________________________________________________________ max_pooling2d (MaxPooling2D) (None, 112, 112, 64) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 112, 112, 128) 73856 _________________________________________________________________ conv2d_3 (Conv2D) (None, 112, 112, 128) 147584 _________________________________________________________________ max_pooling2d_1 (MaxPooling2 (None, 56, 56, 128) 0 _________________________________________________________________ conv2d_4 (Conv2D) (None, 56, 56, 256) 295168 _________________________________________________________________ conv2d_5 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ conv2d_6 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ max_pooling2d_2 (MaxPooling2 (None, 28, 28, 256) 0 _________________________________________________________________ conv2d_7 (Conv2D) (None, 28, 28, 512) 1180160 _________________________________________________________________ conv2d_8 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ conv2d_9 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ max_pooling2d_3 (MaxPooling2 (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_10 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ conv2d_11 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ conv2d_12 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ max_pooling2d_4 (MaxPooling2 (None, 7, 7, 512) 0 _________________________________________________________________ flatten (Flatten) (None, 25088) 0 _________________________________________________________________ dense_4 (Dense) (None, 4096) 102764544 _________________________________________________________________ dense_5 (Dense) (None, 4096) 16781312 _________________________________________________________________ dense_6 (Dense) (None, 1000) 4097000 ================================================================= Total params: 138,357,544 Trainable params: 138,357,544 Non-trainable params: 0 _________________________________________________________________

GoogLeNet / Inception(2014)

ILSVRC 2014 における優勝モデルです。LeNet のオマージュにより名付けられており、Inception という別名も持ちます。こちらのネットワークは提案以降、改良が続けられており、2020 年 3 月現在では Inception v4 が最新です。

GoogLeNet がいわゆる Inception v1 に当たります。

Inception モジュール

GoogLeNet は Inception モジュールとして、複数のネットワークを1つにまとめ、モジュールを積み重ねる、Network In Network の構成がなされています。

Inception モジュールの内部では、下図 (a) のように、異なるフィルタサイズの複数の畳み込み層を同時に通して、結合する処理が行われます。これはパラメータ数を削減して計算コストを下げつつ、複雑なアーキテクチャを組むための工夫です。

どの程度パラメータを削減できるのか、具体的に下記の条件で考えてみましょう。

  • 入力の特徴マップのチャネル数が 192192
  • 出力の特徴マップのチャネル数が 9696

通常の 5×55\times5 の畳み込み層において、バイアスを除いたパラメータ数は

5×5×192×96=460,800\begin{array}{c} 5 \times 5 \times 192 \times 96 = 460,800 \end{array}

となります。一方で、naive(a) の Inception モジュールを使用した場合

(1×1×192×24)+(3×3×192×24)+(5×5×192×24)=161,280\begin{array}{c} &(1 \times 1 \times 192 \times 24)\\ &+ (3 \times 3 \times 192 \times 24)\\ &+ ( 5 \times 5 \times 192 \times 24)\\ &= 161,280 \end{array}

と大きく削減できることがわかります。

Inception モジュールはさらにチャネル方向の次元削減も考慮したパターンも提案されており、上図 (b) がその構成です。

見ての通り、3×33\times35×55\times5 の畳み込み層の前に 1×11\times1 の畳み込み層を重ねています。1×11\times1 の畳み込み層を通して一度チャネル数を減らしてから、より計算コストのかかる 3×33\times35×55\times5 の畳み込み層に通すことで、全体のパラメータ数を減らす考え方です。

こちらも具体的に数を追ってみてみましょう。以下の条件で計算します。

  • 入力の特徴マップのチャネル数が 192192
  • 中間 (1×11\times1 畳み込み層を通した後)の特徴マップのチャネル数が 1616
  • 出力の特徴マップのチャネル数が 9696

(1×1×192×24)+(1×1×192×16)+(3×3×16×24)+(1×1×192×16)+(5×5×16×24)+(1×1×192×24)=28,416\begin{aligned} &(1 \times 1 \times 192 \times 24) \\ &+ (1 \times 1 \times 192 \times 16) + (3 \times 3 \times 16 \times 24) \\ &+ (1 \times 1 \times 192 \times 16) + (5 \times 5 \times 16 \times 24) \\ &+ (1 \times 1 \times 192 \times 24) \\ &= 28,416 \end{aligned}

(a) のパターンでは、161,280161,280 だったので、大きくパラメータ数を削減できたことが確認できます。

(b) パターンの Inception モジュールは以下のように実装することが可能です。

def inception_block(x, f):
    t1 = layers.Conv2D(f[0], 1, activation='relu')(x)
    
    t2 = layers.Conv2D(f[1], 1, activation='relu')(x)
    t2 = layers.Conv2D(f[2], 3, padding='same', activation='relu')(t2)
    
    t3 = layers.Conv2D(f[3], 1, activation='relu')(x)
    t3 = layers.Conv2D(f[4], 5, padding='same', activation='relu')(t3)

    t4 = layers.MaxPool2D(3, 1, padding='same')(x)
    t4 = layers.Conv2D(f[5], 1, activation='relu')(t4)

    out = layers.concatenate([t1, t2, t3, t4])
    
    return out
      

Global Average Pooling (GAP)

また、Global Average Pooling も GoogLeNet で採用されました。こちらは、CNN で特徴マップを全結合層へつなぐ際に使用されます。

これまでは特徴マップの各画素を順番に切り取って並べることでベクトル化していたのに対し、GAP では一つの特徴マップから Average Pooling で 1×11\times1 のサイズにしたものを並べてベクトル化します。

37_4

def googlenet(input_shape, n_classes):
  
    def inception_block(x, f):
        t1 = layers.Conv2D(f[0], 1, activation='relu')(x)

        t2 = layers.Conv2D(f[1], 1, activation='relu')(x)
        t2 = layers.Conv2D(f[2], 3, padding='same', activation='relu')(t2)

        t3 = layers.Conv2D(f[3], 1, activation='relu')(x)
        t3 = layers.Conv2D(f[4], 5, padding='same', activation='relu')(t3)

        t4 = layers.MaxPool2D(3, 1, padding='same')(x)
        t4 = layers.Conv2D(f[5], 1, activation='relu')(t4)

        out = layers.concatenate([t1, t2, t3, t4])

        return out

    input = layers.Input(input_shape)

    x = layers.Conv2D(64, 7, strides=2, padding='same', activation='relu')(input)
    x = layers.MaxPool2D(3, strides=2, padding='same')(x)

    x = layers.Conv2D(64, 1, activation='relu')(x)
    x = layers.Conv2D(192, 3, padding='same', activation='relu')(x)
    x = layers.MaxPool2D(3, strides=2)(x)

    x = inception_block(x, [64, 96, 128, 16, 32, 32])
    x = inception_block(x, [128, 128, 192, 32, 96, 64])
    x = layers.MaxPool2D(3, strides=2, padding='same')(x)

    x = inception_block(x, [192, 96, 208, 16, 48, 64])
    x = inception_block(x, [160, 112, 224, 24, 64, 64])
    x = inception_block(x, [128, 128, 256, 24, 64, 64])
    x = inception_block(x, [112, 144, 288, 32, 64, 64])
    x = inception_block(x, [256, 160, 320, 32, 128, 128])
    x = layers.MaxPool2D(3, strides=2, padding='same')(x)

    x = inception_block(x, [256, 160, 320, 32, 128, 128])
    x = inception_block(x, [384, 192, 384, 48, 128, 128])

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.4)(x)

    output = layers.Dense(n_classes, activation='softmax')(x)

    model = models.Model(input, output)

    return model
      
model = googlenet((224, 224, 3), 1000)
model.summary()
      
Model: "model_2" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_3 (InputLayer) [(None, 224, 224, 3) 0 __________________________________________________________________________________________________ conv2d_13 (Conv2D) (None, 112, 112, 64) 9472 input_3[0][0] __________________________________________________________________________________________________ max_pooling2d_5 (MaxPooling2D) (None, 56, 56, 64) 0 conv2d_13[0][0] __________________________________________________________________________________________________ conv2d_14 (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_5[0][0] __________________________________________________________________________________________________ conv2d_15 (Conv2D) (None, 56, 56, 192) 110784 conv2d_14[0][0] __________________________________________________________________________________________________ max_pooling2d_6 (MaxPooling2D) (None, 27, 27, 192) 0 conv2d_15[0][0] __________________________________________________________________________________________________ conv2d_17 (Conv2D) (None, 27, 27, 96) 18528 max_pooling2d_6[0][0] __________________________________________________________________________________________________ conv2d_19 (Conv2D) (None, 27, 27, 16) 3088 max_pooling2d_6[0][0] __________________________________________________________________________________________________ max_pooling2d_7 (MaxPooling2D) (None, 27, 27, 192) 0 max_pooling2d_6[0][0] __________________________________________________________________________________________________ conv2d_16 (Conv2D) (None, 27, 27, 64) 12352 max_pooling2d_6[0][0] __________________________________________________________________________________________________ conv2d_18 (Conv2D) (None, 27, 27, 128) 110720 conv2d_17[0][0] __________________________________________________________________________________________________ conv2d_20 (Conv2D) (None, 27, 27, 32) 12832 conv2d_19[0][0] __________________________________________________________________________________________________ conv2d_21 (Conv2D) (None, 27, 27, 32) 6176 max_pooling2d_7[0][0] __________________________________________________________________________________________________ concatenate (Concatenate) (None, 27, 27, 256) 0 conv2d_16[0][0] conv2d_18[0][0] conv2d_20[0][0] conv2d_21[0][0] __________________________________________________________________________________________________ conv2d_23 (Conv2D) (None, 27, 27, 128) 32896 concatenate[0][0] __________________________________________________________________________________________________ conv2d_25 (Conv2D) (None, 27, 27, 32) 8224 concatenate[0][0] __________________________________________________________________________________________________ max_pooling2d_8 (MaxPooling2D) (None, 27, 27, 256) 0 concatenate[0][0] __________________________________________________________________________________________________ conv2d_22 (Conv2D) (None, 27, 27, 128) 32896 concatenate[0][0] __________________________________________________________________________________________________ conv2d_24 (Conv2D) (None, 27, 27, 192) 221376 conv2d_23[0][0] __________________________________________________________________________________________________ conv2d_26 (Conv2D) (None, 27, 27, 96) 76896 conv2d_25[0][0] __________________________________________________________________________________________________ conv2d_27 (Conv2D) (None, 27, 27, 64) 16448 max_pooling2d_8[0][0] __________________________________________________________________________________________________ concatenate_1 (Concatenate) (None, 27, 27, 480) 0 conv2d_22[0][0] conv2d_24[0][0] conv2d_26[0][0] conv2d_27[0][0] __________________________________________________________________________________________________ max_pooling2d_9 (MaxPooling2D) (None, 14, 14, 480) 0 concatenate_1[0][0] __________________________________________________________________________________________________ conv2d_29 (Conv2D) (None, 14, 14, 96) 46176 max_pooling2d_9[0][0] __________________________________________________________________________________________________ conv2d_31 (Conv2D) (None, 14, 14, 16) 7696 max_pooling2d_9[0][0] __________________________________________________________________________________________________ max_pooling2d_10 (MaxPooling2D) (None, 14, 14, 480) 0 max_pooling2d_9[0][0] __________________________________________________________________________________________________ conv2d_28 (Conv2D) (None, 14, 14, 192) 92352 max_pooling2d_9[0][0] __________________________________________________________________________________________________ conv2d_30 (Conv2D) (None, 14, 14, 208) 179920 conv2d_29[0][0] __________________________________________________________________________________________________ conv2d_32 (Conv2D) (None, 14, 14, 48) 19248 conv2d_31[0][0] __________________________________________________________________________________________________ conv2d_33 (Conv2D) (None, 14, 14, 64) 30784 max_pooling2d_10[0][0] __________________________________________________________________________________________________ concatenate_2 (Concatenate) (None, 14, 14, 512) 0 conv2d_28[0][0] conv2d_30[0][0] conv2d_32[0][0] conv2d_33[0][0] __________________________________________________________________________________________________ conv2d_35 (Conv2D) (None, 14, 14, 112) 57456 concatenate_2[0][0] __________________________________________________________________________________________________ conv2d_37 (Conv2D) (None, 14, 14, 24) 12312 concatenate_2[0][0] __________________________________________________________________________________________________ max_pooling2d_11 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_2[0][0] __________________________________________________________________________________________________ conv2d_34 (Conv2D) (None, 14, 14, 160) 82080 concatenate_2[0][0] __________________________________________________________________________________________________ conv2d_36 (Conv2D) (None, 14, 14, 224) 226016 conv2d_35[0][0] __________________________________________________________________________________________________ conv2d_38 (Conv2D) (None, 14, 14, 64) 38464 conv2d_37[0][0] __________________________________________________________________________________________________ conv2d_39 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_11[0][0] __________________________________________________________________________________________________ concatenate_3 (Concatenate) (None, 14, 14, 512) 0 conv2d_34[0][0] conv2d_36[0][0] conv2d_38[0][0] conv2d_39[0][0] __________________________________________________________________________________________________ conv2d_41 (Conv2D) (None, 14, 14, 128) 65664 concatenate_3[0][0] __________________________________________________________________________________________________ conv2d_43 (Conv2D) (None, 14, 14, 24) 12312 concatenate_3[0][0] __________________________________________________________________________________________________ max_pooling2d_12 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_3[0][0] __________________________________________________________________________________________________ conv2d_40 (Conv2D) (None, 14, 14, 128) 65664 concatenate_3[0][0] __________________________________________________________________________________________________ conv2d_42 (Conv2D) (None, 14, 14, 256) 295168 conv2d_41[0][0] __________________________________________________________________________________________________ conv2d_44 (Conv2D) (None, 14, 14, 64) 38464 conv2d_43[0][0] __________________________________________________________________________________________________ conv2d_45 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_12[0][0] __________________________________________________________________________________________________ concatenate_4 (Concatenate) (None, 14, 14, 512) 0 conv2d_40[0][0] conv2d_42[0][0] conv2d_44[0][0] conv2d_45[0][0] __________________________________________________________________________________________________ conv2d_47 (Conv2D) (None, 14, 14, 144) 73872 concatenate_4[0][0] __________________________________________________________________________________________________ conv2d_49 (Conv2D) (None, 14, 14, 32) 16416 concatenate_4[0][0] __________________________________________________________________________________________________ max_pooling2d_13 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_4[0][0] __________________________________________________________________________________________________ conv2d_46 (Conv2D) (None, 14, 14, 112) 57456 concatenate_4[0][0] __________________________________________________________________________________________________ conv2d_48 (Conv2D) (None, 14, 14, 288) 373536 conv2d_47[0][0] __________________________________________________________________________________________________ conv2d_50 (Conv2D) (None, 14, 14, 64) 51264 conv2d_49[0][0] __________________________________________________________________________________________________ conv2d_51 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_13[0][0] __________________________________________________________________________________________________ concatenate_5 (Concatenate) (None, 14, 14, 528) 0 conv2d_46[0][0] conv2d_48[0][0] conv2d_50[0][0] conv2d_51[0][0] __________________________________________________________________________________________________ conv2d_53 (Conv2D) (None, 14, 14, 160) 84640 concatenate_5[0][0] __________________________________________________________________________________________________ conv2d_55 (Conv2D) (None, 14, 14, 32) 16928 concatenate_5[0][0] __________________________________________________________________________________________________ max_pooling2d_14 (MaxPooling2D) (None, 14, 14, 528) 0 concatenate_5[0][0] __________________________________________________________________________________________________ conv2d_52 (Conv2D) (None, 14, 14, 256) 135424 concatenate_5[0][0] __________________________________________________________________________________________________ conv2d_54 (Conv2D) (None, 14, 14, 320) 461120 conv2d_53[0][0] __________________________________________________________________________________________________ conv2d_56 (Conv2D) (None, 14, 14, 128) 102528 conv2d_55[0][0] __________________________________________________________________________________________________ conv2d_57 (Conv2D) (None, 14, 14, 128) 67712 max_pooling2d_14[0][0] __________________________________________________________________________________________________ concatenate_6 (Concatenate) (None, 14, 14, 832) 0 conv2d_52[0][0] conv2d_54[0][0] conv2d_56[0][0] conv2d_57[0][0] __________________________________________________________________________________________________ max_pooling2d_15 (MaxPooling2D) (None, 7, 7, 832) 0 concatenate_6[0][0] __________________________________________________________________________________________________ conv2d_59 (Conv2D) (None, 7, 7, 160) 133280 max_pooling2d_15[0][0] __________________________________________________________________________________________________ conv2d_61 (Conv2D) (None, 7, 7, 32) 26656 max_pooling2d_15[0][0] __________________________________________________________________________________________________ max_pooling2d_16 (MaxPooling2D) (None, 7, 7, 832) 0 max_pooling2d_15[0][0] __________________________________________________________________________________________________ conv2d_58 (Conv2D) (None, 7, 7, 256) 213248 max_pooling2d_15[0][0] __________________________________________________________________________________________________ conv2d_60 (Conv2D) (None, 7, 7, 320) 461120 conv2d_59[0][0] __________________________________________________________________________________________________ conv2d_62 (Conv2D) (None, 7, 7, 128) 102528 conv2d_61[0][0] __________________________________________________________________________________________________ conv2d_63 (Conv2D) (None, 7, 7, 128) 106624 max_pooling2d_16[0][0] __________________________________________________________________________________________________ concatenate_7 (Concatenate) (None, 7, 7, 832) 0 conv2d_58[0][0] conv2d_60[0][0] conv2d_62[0][0] conv2d_63[0][0] __________________________________________________________________________________________________ conv2d_65 (Conv2D) (None, 7, 7, 192) 159936 concatenate_7[0][0] __________________________________________________________________________________________________ conv2d_67 (Conv2D) (None, 7, 7, 48) 39984 concatenate_7[0][0] __________________________________________________________________________________________________ max_pooling2d_17 (MaxPooling2D) (None, 7, 7, 832) 0 concatenate_7[0][0] __________________________________________________________________________________________________ conv2d_64 (Conv2D) (None, 7, 7, 384) 319872 concatenate_7[0][0] __________________________________________________________________________________________________ conv2d_66 (Conv2D) (None, 7, 7, 384) 663936 conv2d_65[0][0] __________________________________________________________________________________________________ conv2d_68 (Conv2D) (None, 7, 7, 128) 153728 conv2d_67[0][0] __________________________________________________________________________________________________ conv2d_69 (Conv2D) (None, 7, 7, 128) 106624 max_pooling2d_17[0][0] __________________________________________________________________________________________________ concatenate_8 (Concatenate) (None, 7, 7, 1024) 0 conv2d_64[0][0] conv2d_66[0][0] conv2d_68[0][0] conv2d_69[0][0] __________________________________________________________________________________________________ global_average_pooling2d (Globa (None, 1024) 0 concatenate_8[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 1024) 0 global_average_pooling2d[0][0] __________________________________________________________________________________________________ dense_7 (Dense) (None, 1000) 1025000 dropout[0][0] ================================================================================================== Total params: 6,998,552 Trainable params: 6,998,552 Non-trainable params: 0 __________________________________________________________________________________________________

ResNet(2015)

Residual モジュール

ResNet の特徴は、名前の由来にもなっている Residual モジュールを採用している点です。

層を深くすることで精度が向上することがわかり、深くしたいモチベーションがある一方で、深くしすぎると逆伝播時に勾配消失してしまう問題がありました。この問題に対する工夫として提案されたのが Residual モジュールです。

37_5

出典:https://arxiv.org/pdf/1512.03385.pdf

仕組みとしては、畳み込み層への入力を分岐させ、1 層先の畳み込み層の出力と結合させます。

こうすることで、逆伝播の際に微分を行うと、

ddx(f(x)+x)=ddxf(x)+1\frac{d}{d x}(f(x) + x) = \frac{d}{d x}f(x) + 1

と通常に加えて 1 増えます。これにより、勾配消失を防いで層を深くすることが可能となりました。

Bottleneck モジュール

Residualモジュールに、1×11\times1 畳み込みを加えて、パラメータを削減しより効率的に学習を行えるモジュールも提案されました。

37_6

出典:https://arxiv.org/pdf/1512.03385.pdf

こちらは、間でチャネル数がいったん小さくなった後、元の大きさに戻って出ていく様が、ボトルネックの形状と似ていることから、Bottleneck モジュールとも呼ばれます。

He の初期化

活性化関数に ReLU を用いる際の、最適な重みの初期値もここで提案されました。重みの初期値は、平均 0 標準偏差 1 の正規分布からランダムに設定されました。

He の初期化では、前層から渡されるノード数が nn 個である場合には、重みの初期値を平均 0、標準偏差 2n\sqrt{\frac{2}{n}} の正規分布から生成します。

Batch Normarlization

精度向上でも紹介した Batch Normalization もここで提案されました。

様々なバリエーション

ResNet は最大 152 層まで深くしたアーキテクチャが提案されています。

各バリエーションの構成は以下の通りです。

  • ResNet18
  • ResNet34
  • ResNet50
  • ResNet101
  • ResNet152

37_7

出典:https://arxiv.org/pdf/1512.03385.pdf

def resnet(input_shape, n_classes):

    def conv_bn_rl(x, f, k=1, s=1, p='same'):
        x = layers.Conv2D(f, k, strides=s, padding=p)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        return x


    def identity_block(tensor, f):
        x = conv_bn_rl(tensor, f)
        x = conv_bn_rl(x, f, 3)
        x = layers.Conv2D(4*f, 1)(x)
        x = layers.BatchNormalization()(x)

        x = layers.add([x, tensor])
        output = layers.ReLU()(x)
        return output


    def conv_block(tensor, f, s):
        x = conv_bn_rl(tensor, f)
        x = conv_bn_rl(x, f, 3, s)
        x = layers.Conv2D(4*f, 1)(x)
        x = layers.BatchNormalization()(x)

        shortcut = layers.Conv2D(4*f, 1, strides=s)(tensor)
        shortcut = layers.BatchNormalization()(shortcut)

        x = layers.add([x, shortcut])
        output = layers.ReLU()(x)
        return output


    def resnet_block(x, f, r, s=2):
        x = conv_block(x, f, s)
        for _ in range(r-1):
            x = identity_block(x, f)
        return x


    input = layers.Input(input_shape)

    x = conv_bn_rl(input, 64, 7, 2)
    x = layers.MaxPool2D(3, strides=2, padding='same')(x)

    x = resnet_block(x, 64, 3, 1)
    x = resnet_block(x, 128, 4)
    x = resnet_block(x, 256, 6)
    x = resnet_block(x, 512, 3)

    x = layers.GlobalAveragePooling2D()(x)

    output = layers.Dense(n_classes, activation='softmax')(x)

    model = models.Model(input, output)
    return model
      
model = resnet((224, 224, 3), 1000)
model.summary()
      
Model: "model_3" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_4 (InputLayer) [(None, 224, 224, 3) 0 __________________________________________________________________________________________________ conv2d_70 (Conv2D) (None, 112, 112, 64) 9472 input_4[0][0] __________________________________________________________________________________________________ batch_normalization (BatchNorma (None, 112, 112, 64) 256 conv2d_70[0][0] __________________________________________________________________________________________________ re_lu (ReLU) (None, 112, 112, 64) 0 batch_normalization[0][0] __________________________________________________________________________________________________ max_pooling2d_18 (MaxPooling2D) (None, 56, 56, 64) 0 re_lu[0][0] __________________________________________________________________________________________________ conv2d_71 (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_18[0][0] __________________________________________________________________________________________________ batch_normalization_1 (BatchNor (None, 56, 56, 64) 256 conv2d_71[0][0] __________________________________________________________________________________________________ re_lu_1 (ReLU) (None, 56, 56, 64) 0 batch_normalization_1[0][0] __________________________________________________________________________________________________ conv2d_72 (Conv2D) (None, 56, 56, 64) 36928 re_lu_1[0][0] __________________________________________________________________________________________________ batch_normalization_2 (BatchNor (None, 56, 56, 64) 256 conv2d_72[0][0] __________________________________________________________________________________________________ re_lu_2 (ReLU) (None, 56, 56, 64) 0 batch_normalization_2[0][0] __________________________________________________________________________________________________ conv2d_73 (Conv2D) (None, 56, 56, 256) 16640 re_lu_2[0][0] __________________________________________________________________________________________________ conv2d_74 (Conv2D) (None, 56, 56, 256) 16640 max_pooling2d_18[0][0] __________________________________________________________________________________________________ batch_normalization_3 (BatchNor (None, 56, 56, 256) 1024 conv2d_73[0][0] __________________________________________________________________________________________________ batch_normalization_4 (BatchNor (None, 56, 56, 256) 1024 conv2d_74[0][0] __________________________________________________________________________________________________ add (Add) (None, 56, 56, 256) 0 batch_normalization_3[0][0] batch_normalization_4[0][0] __________________________________________________________________________________________________ re_lu_3 (ReLU) (None, 56, 56, 256) 0 add[0][0] __________________________________________________________________________________________________ conv2d_75 (Conv2D) (None, 56, 56, 64) 16448 re_lu_3[0][0] __________________________________________________________________________________________________ batch_normalization_5 (BatchNor (None, 56, 56, 64) 256 conv2d_75[0][0] __________________________________________________________________________________________________ re_lu_4 (ReLU) (None, 56, 56, 64) 0 batch_normalization_5[0][0] __________________________________________________________________________________________________ conv2d_76 (Conv2D) (None, 56, 56, 64) 36928 re_lu_4[0][0] __________________________________________________________________________________________________ batch_normalization_6 (BatchNor (None, 56, 56, 64) 256 conv2d_76[0][0] __________________________________________________________________________________________________ re_lu_5 (ReLU) (None, 56, 56, 64) 0 batch_normalization_6[0][0] __________________________________________________________________________________________________ conv2d_77 (Conv2D) (None, 56, 56, 256) 16640 re_lu_5[0][0] __________________________________________________________________________________________________ batch_normalization_7 (BatchNor (None, 56, 56, 256) 1024 conv2d_77[0][0] __________________________________________________________________________________________________ add_1 (Add) (None, 56, 56, 256) 0 batch_normalization_7[0][0] re_lu_3[0][0] __________________________________________________________________________________________________ re_lu_6 (ReLU) (None, 56, 56, 256) 0 add_1[0][0] __________________________________________________________________________________________________ conv2d_78 (Conv2D) (None, 56, 56, 64) 16448 re_lu_6[0][0] __________________________________________________________________________________________________ batch_normalization_8 (BatchNor (None, 56, 56, 64) 256 conv2d_78[0][0] __________________________________________________________________________________________________ re_lu_7 (ReLU) (None, 56, 56, 64) 0 batch_normalization_8[0][0] __________________________________________________________________________________________________ conv2d_79 (Conv2D) (None, 56, 56, 64) 36928 re_lu_7[0][0] __________________________________________________________________________________________________ batch_normalization_9 (BatchNor (None, 56, 56, 64) 256 conv2d_79[0][0] __________________________________________________________________________________________________ re_lu_8 (ReLU) (None, 56, 56, 64) 0 batch_normalization_9[0][0] __________________________________________________________________________________________________ conv2d_80 (Conv2D) (None, 56, 56, 256) 16640 re_lu_8[0][0] __________________________________________________________________________________________________ batch_normalization_10 (BatchNo (None, 56, 56, 256) 1024 conv2d_80[0][0] __________________________________________________________________________________________________ add_2 (Add) (None, 56, 56, 256) 0 batch_normalization_10[0][0] re_lu_6[0][0] __________________________________________________________________________________________________ re_lu_9 (ReLU) (None, 56, 56, 256) 0 add_2[0][0] __________________________________________________________________________________________________ conv2d_81 (Conv2D) (None, 56, 56, 128) 32896 re_lu_9[0][0] __________________________________________________________________________________________________ batch_normalization_11 (BatchNo (None, 56, 56, 128) 512 conv2d_81[0][0] __________________________________________________________________________________________________ re_lu_10 (ReLU) (None, 56, 56, 128) 0 batch_normalization_11[0][0] __________________________________________________________________________________________________ conv2d_82 (Conv2D) (None, 28, 28, 128) 147584 re_lu_10[0][0] __________________________________________________________________________________________________ batch_normalization_12 (BatchNo (None, 28, 28, 128) 512 conv2d_82[0][0] __________________________________________________________________________________________________ re_lu_11 (ReLU) (None, 28, 28, 128) 0 batch_normalization_12[0][0] __________________________________________________________________________________________________ conv2d_83 (Conv2D) (None, 28, 28, 512) 66048 re_lu_11[0][0] __________________________________________________________________________________________________ conv2d_84 (Conv2D) (None, 28, 28, 512) 131584 re_lu_9[0][0] __________________________________________________________________________________________________ batch_normalization_13 (BatchNo (None, 28, 28, 512) 2048 conv2d_83[0][0] __________________________________________________________________________________________________ batch_normalization_14 (BatchNo (None, 28, 28, 512) 2048 conv2d_84[0][0] __________________________________________________________________________________________________ add_3 (Add) (None, 28, 28, 512) 0 batch_normalization_13[0][0] batch_normalization_14[0][0] __________________________________________________________________________________________________ re_lu_12 (ReLU) (None, 28, 28, 512) 0 add_3[0][0] __________________________________________________________________________________________________ conv2d_85 (Conv2D) (None, 28, 28, 128) 65664 re_lu_12[0][0] __________________________________________________________________________________________________ batch_normalization_15 (BatchNo (None, 28, 28, 128) 512 conv2d_85[0][0] __________________________________________________________________________________________________ re_lu_13 (ReLU) (None, 28, 28, 128) 0 batch_normalization_15[0][0] __________________________________________________________________________________________________ conv2d_86 (Conv2D) (None, 28, 28, 128) 147584 re_lu_13[0][0] __________________________________________________________________________________________________ batch_normalization_16 (BatchNo (None, 28, 28, 128) 512 conv2d_86[0][0] __________________________________________________________________________________________________ re_lu_14 (ReLU) (None, 28, 28, 128) 0 batch_normalization_16[0][0] __________________________________________________________________________________________________ conv2d_87 (Conv2D) (None, 28, 28, 512) 66048 re_lu_14[0][0] __________________________________________________________________________________________________ batch_normalization_17 (BatchNo (None, 28, 28, 512) 2048 conv2d_87[0][0] __________________________________________________________________________________________________ add_4 (Add) (None, 28, 28, 512) 0 batch_normalization_17[0][0] re_lu_12[0][0] __________________________________________________________________________________________________ re_lu_15 (ReLU) (None, 28, 28, 512) 0 add_4[0][0] __________________________________________________________________________________________________ conv2d_88 (Conv2D) (None, 28, 28, 128) 65664 re_lu_15[0][0] __________________________________________________________________________________________________ batch_normalization_18 (BatchNo (None, 28, 28, 128) 512 conv2d_88[0][0] __________________________________________________________________________________________________ re_lu_16 (ReLU) (None, 28, 28, 128) 0 batch_normalization_18[0][0] __________________________________________________________________________________________________ conv2d_89 (Conv2D) (None, 28, 28, 128) 147584 re_lu_16[0][0] __________________________________________________________________________________________________ batch_normalization_19 (BatchNo (None, 28, 28, 128) 512 conv2d_89[0][0] __________________________________________________________________________________________________ re_lu_17 (ReLU) (None, 28, 28, 128) 0 batch_normalization_19[0][0] __________________________________________________________________________________________________ conv2d_90 (Conv2D) (None, 28, 28, 512) 66048 re_lu_17[0][0] __________________________________________________________________________________________________ batch_normalization_20 (BatchNo (None, 28, 28, 512) 2048 conv2d_90[0][0] __________________________________________________________________________________________________ add_5 (Add) (None, 28, 28, 512) 0 batch_normalization_20[0][0] re_lu_15[0][0] __________________________________________________________________________________________________ re_lu_18 (ReLU) (None, 28, 28, 512) 0 add_5[0][0] __________________________________________________________________________________________________ conv2d_91 (Conv2D) (None, 28, 28, 128) 65664 re_lu_18[0][0] __________________________________________________________________________________________________ batch_normalization_21 (BatchNo (None, 28, 28, 128) 512 conv2d_91[0][0] __________________________________________________________________________________________________ re_lu_19 (ReLU) (None, 28, 28, 128) 0 batch_normalization_21[0][0] __________________________________________________________________________________________________ conv2d_92 (Conv2D) (None, 28, 28, 128) 147584 re_lu_19[0][0] __________________________________________________________________________________________________ batch_normalization_22 (BatchNo (None, 28, 28, 128) 512 conv2d_92[0][0] __________________________________________________________________________________________________ re_lu_20 (ReLU) (None, 28, 28, 128) 0 batch_normalization_22[0][0] __________________________________________________________________________________________________ conv2d_93 (Conv2D) (None, 28, 28, 512) 66048 re_lu_20[0][0] __________________________________________________________________________________________________ batch_normalization_23 (BatchNo (None, 28, 28, 512) 2048 conv2d_93[0][0] __________________________________________________________________________________________________ add_6 (Add) (None, 28, 28, 512) 0 batch_normalization_23[0][0] re_lu_18[0][0] __________________________________________________________________________________________________ re_lu_21 (ReLU) (None, 28, 28, 512) 0 add_6[0][0] __________________________________________________________________________________________________ conv2d_94 (Conv2D) (None, 28, 28, 256) 131328 re_lu_21[0][0] __________________________________________________________________________________________________ batch_normalization_24 (BatchNo (None, 28, 28, 256) 1024 conv2d_94[0][0] __________________________________________________________________________________________________ re_lu_22 (ReLU) (None, 28, 28, 256) 0 batch_normalization_24[0][0] __________________________________________________________________________________________________ conv2d_95 (Conv2D) (None, 14, 14, 256) 590080 re_lu_22[0][0] __________________________________________________________________________________________________ batch_normalization_25 (BatchNo (None, 14, 14, 256) 1024 conv2d_95[0][0] __________________________________________________________________________________________________ re_lu_23 (ReLU) (None, 14, 14, 256) 0 batch_normalization_25[0][0] __________________________________________________________________________________________________ conv2d_96 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_23[0][0] __________________________________________________________________________________________________ conv2d_97 (Conv2D) (None, 14, 14, 1024) 525312 re_lu_21[0][0] __________________________________________________________________________________________________ batch_normalization_26 (BatchNo (None, 14, 14, 1024) 4096 conv2d_96[0][0] __________________________________________________________________________________________________ batch_normalization_27 (BatchNo (None, 14, 14, 1024) 4096 conv2d_97[0][0] __________________________________________________________________________________________________ add_7 (Add) (None, 14, 14, 1024) 0 batch_normalization_26[0][0] batch_normalization_27[0][0] __________________________________________________________________________________________________ re_lu_24 (ReLU) (None, 14, 14, 1024) 0 add_7[0][0] __________________________________________________________________________________________________ conv2d_98 (Conv2D) (None, 14, 14, 256) 262400 re_lu_24[0][0] __________________________________________________________________________________________________ batch_normalization_28 (BatchNo (None, 14, 14, 256) 1024 conv2d_98[0][0] __________________________________________________________________________________________________ re_lu_25 (ReLU) (None, 14, 14, 256) 0 batch_normalization_28[0][0] __________________________________________________________________________________________________ conv2d_99 (Conv2D) (None, 14, 14, 256) 590080 re_lu_25[0][0] __________________________________________________________________________________________________ batch_normalization_29 (BatchNo (None, 14, 14, 256) 1024 conv2d_99[0][0] __________________________________________________________________________________________________ re_lu_26 (ReLU) (None, 14, 14, 256) 0 batch_normalization_29[0][0] __________________________________________________________________________________________________ conv2d_100 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_26[0][0] __________________________________________________________________________________________________ batch_normalization_30 (BatchNo (None, 14, 14, 1024) 4096 conv2d_100[0][0] __________________________________________________________________________________________________ add_8 (Add) (None, 14, 14, 1024) 0 batch_normalization_30[0][0] re_lu_24[0][0] __________________________________________________________________________________________________ re_lu_27 (ReLU) (None, 14, 14, 1024) 0 add_8[0][0] __________________________________________________________________________________________________ conv2d_101 (Conv2D) (None, 14, 14, 256) 262400 re_lu_27[0][0] __________________________________________________________________________________________________ batch_normalization_31 (BatchNo (None, 14, 14, 256) 1024 conv2d_101[0][0] __________________________________________________________________________________________________ re_lu_28 (ReLU) (None, 14, 14, 256) 0 batch_normalization_31[0][0] __________________________________________________________________________________________________ conv2d_102 (Conv2D) (None, 14, 14, 256) 590080 re_lu_28[0][0] __________________________________________________________________________________________________ batch_normalization_32 (BatchNo (None, 14, 14, 256) 1024 conv2d_102[0][0] __________________________________________________________________________________________________ re_lu_29 (ReLU) (None, 14, 14, 256) 0 batch_normalization_32[0][0] __________________________________________________________________________________________________ conv2d_103 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_29[0][0] __________________________________________________________________________________________________ batch_normalization_33 (BatchNo (None, 14, 14, 1024) 4096 conv2d_103[0][0] __________________________________________________________________________________________________ add_9 (Add) (None, 14, 14, 1024) 0 batch_normalization_33[0][0] re_lu_27[0][0] __________________________________________________________________________________________________ re_lu_30 (ReLU) (None, 14, 14, 1024) 0 add_9[0][0] __________________________________________________________________________________________________ conv2d_104 (Conv2D) (None, 14, 14, 256) 262400 re_lu_30[0][0] __________________________________________________________________________________________________ batch_normalization_34 (BatchNo (None, 14, 14, 256) 1024 conv2d_104[0][0] __________________________________________________________________________________________________ re_lu_31 (ReLU) (None, 14, 14, 256) 0 batch_normalization_34[0][0] __________________________________________________________________________________________________ conv2d_105 (Conv2D) (None, 14, 14, 256) 590080 re_lu_31[0][0] __________________________________________________________________________________________________ batch_normalization_35 (BatchNo (None, 14, 14, 256) 1024 conv2d_105[0][0] __________________________________________________________________________________________________ re_lu_32 (ReLU) (None, 14, 14, 256) 0 batch_normalization_35[0][0] __________________________________________________________________________________________________ conv2d_106 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_32[0][0] __________________________________________________________________________________________________ batch_normalization_36 (BatchNo (None, 14, 14, 1024) 4096 conv2d_106[0][0] __________________________________________________________________________________________________ add_10 (Add) (None, 14, 14, 1024) 0 batch_normalization_36[0][0] re_lu_30[0][0] __________________________________________________________________________________________________ re_lu_33 (ReLU) (None, 14, 14, 1024) 0 add_10[0][0] __________________________________________________________________________________________________ conv2d_107 (Conv2D) (None, 14, 14, 256) 262400 re_lu_33[0][0] __________________________________________________________________________________________________ batch_normalization_37 (BatchNo (None, 14, 14, 256) 1024 conv2d_107[0][0] __________________________________________________________________________________________________ re_lu_34 (ReLU) (None, 14, 14, 256) 0 batch_normalization_37[0][0] __________________________________________________________________________________________________ conv2d_108 (Conv2D) (None, 14, 14, 256) 590080 re_lu_34[0][0] __________________________________________________________________________________________________ batch_normalization_38 (BatchNo (None, 14, 14, 256) 1024 conv2d_108[0][0] __________________________________________________________________________________________________ re_lu_35 (ReLU) (None, 14, 14, 256) 0 batch_normalization_38[0][0] __________________________________________________________________________________________________ conv2d_109 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_35[0][0] __________________________________________________________________________________________________ batch_normalization_39 (BatchNo (None, 14, 14, 1024) 4096 conv2d_109[0][0] __________________________________________________________________________________________________ add_11 (Add) (None, 14, 14, 1024) 0 batch_normalization_39[0][0] re_lu_33[0][0] __________________________________________________________________________________________________ re_lu_36 (ReLU) (None, 14, 14, 1024) 0 add_11[0][0] __________________________________________________________________________________________________ conv2d_110 (Conv2D) (None, 14, 14, 256) 262400 re_lu_36[0][0] __________________________________________________________________________________________________ batch_normalization_40 (BatchNo (None, 14, 14, 256) 1024 conv2d_110[0][0] __________________________________________________________________________________________________ re_lu_37 (ReLU) (None, 14, 14, 256) 0 batch_normalization_40[0][0] __________________________________________________________________________________________________ conv2d_111 (Conv2D) (None, 14, 14, 256) 590080 re_lu_37[0][0] __________________________________________________________________________________________________ batch_normalization_41 (BatchNo (None, 14, 14, 256) 1024 conv2d_111[0][0] __________________________________________________________________________________________________ re_lu_38 (ReLU) (None, 14, 14, 256) 0 batch_normalization_41[0][0] __________________________________________________________________________________________________ conv2d_112 (Conv2D) (None, 14, 14, 1024) 263168 re_lu_38[0][0] __________________________________________________________________________________________________ batch_normalization_42 (BatchNo (None, 14, 14, 1024) 4096 conv2d_112[0][0] __________________________________________________________________________________________________ add_12 (Add) (None, 14, 14, 1024) 0 batch_normalization_42[0][0] re_lu_36[0][0] __________________________________________________________________________________________________ re_lu_39 (ReLU) (None, 14, 14, 1024) 0 add_12[0][0] __________________________________________________________________________________________________ conv2d_113 (Conv2D) (None, 14, 14, 512) 524800 re_lu_39[0][0] __________________________________________________________________________________________________ batch_normalization_43 (BatchNo (None, 14, 14, 512) 2048 conv2d_113[0][0] __________________________________________________________________________________________________ re_lu_40 (ReLU) (None, 14, 14, 512) 0 batch_normalization_43[0][0] __________________________________________________________________________________________________ conv2d_114 (Conv2D) (None, 7, 7, 512) 2359808 re_lu_40[0][0] __________________________________________________________________________________________________ batch_normalization_44 (BatchNo (None, 7, 7, 512) 2048 conv2d_114[0][0] __________________________________________________________________________________________________ re_lu_41 (ReLU) (None, 7, 7, 512) 0 batch_normalization_44[0][0] __________________________________________________________________________________________________ conv2d_115 (Conv2D) (None, 7, 7, 2048) 1050624 re_lu_41[0][0] __________________________________________________________________________________________________ conv2d_116 (Conv2D) (None, 7, 7, 2048) 2099200 re_lu_39[0][0] __________________________________________________________________________________________________ batch_normalization_45 (BatchNo (None, 7, 7, 2048) 8192 conv2d_115[0][0] __________________________________________________________________________________________________ batch_normalization_46 (BatchNo (None, 7, 7, 2048) 8192 conv2d_116[0][0] __________________________________________________________________________________________________ add_13 (Add) (None, 7, 7, 2048) 0 batch_normalization_45[0][0] batch_normalization_46[0][0] __________________________________________________________________________________________________
re_lu_42 (ReLU) (None, 7, 7, 2048) 0 add_13[0][0] __________________________________________________________________________________________________ conv2d_117 (Conv2D) (None, 7, 7, 512) 1049088 re_lu_42[0][0] __________________________________________________________________________________________________ batch_normalization_47 (BatchNo (None, 7, 7, 512) 2048 conv2d_117[0][0] __________________________________________________________________________________________________ re_lu_43 (ReLU) (None, 7, 7, 512) 0 batch_normalization_47[0][0] __________________________________________________________________________________________________ conv2d_118 (Conv2D) (None, 7, 7, 512) 2359808 re_lu_43[0][0] __________________________________________________________________________________________________ batch_normalization_48 (BatchNo (None, 7, 7, 512) 2048 conv2d_118[0][0] __________________________________________________________________________________________________ re_lu_44 (ReLU) (None, 7, 7, 512) 0 batch_normalization_48[0][0] __________________________________________________________________________________________________ conv2d_119 (Conv2D) (None, 7, 7, 2048) 1050624 re_lu_44[0][0] __________________________________________________________________________________________________ batch_normalization_49 (BatchNo (None, 7, 7, 2048) 8192 conv2d_119[0][0] __________________________________________________________________________________________________ add_14 (Add) (None, 7, 7, 2048) 0 batch_normalization_49[0][0] re_lu_42[0][0] __________________________________________________________________________________________________ re_lu_45 (ReLU) (None, 7, 7, 2048) 0 add_14[0][0] __________________________________________________________________________________________________ conv2d_120 (Conv2D) (None, 7, 7, 512) 1049088 re_lu_45[0][0] __________________________________________________________________________________________________ batch_normalization_50 (BatchNo (None, 7, 7, 512) 2048 conv2d_120[0][0] __________________________________________________________________________________________________ re_lu_46 (ReLU) (None, 7, 7, 512) 0 batch_normalization_50[0][0] __________________________________________________________________________________________________ conv2d_121 (Conv2D) (None, 7, 7, 512) 2359808 re_lu_46[0][0] __________________________________________________________________________________________________ batch_normalization_51 (BatchNo (None, 7, 7, 512) 2048 conv2d_121[0][0] __________________________________________________________________________________________________ re_lu_47 (ReLU) (None, 7, 7, 512) 0 batch_normalization_51[0][0] __________________________________________________________________________________________________ conv2d_122 (Conv2D) (None, 7, 7, 2048) 1050624 re_lu_47[0][0] __________________________________________________________________________________________________ batch_normalization_52 (BatchNo (None, 7, 7, 2048) 8192 conv2d_122[0][0] __________________________________________________________________________________________________ add_15 (Add) (None, 7, 7, 2048) 0 batch_normalization_52[0][0] re_lu_45[0][0] __________________________________________________________________________________________________ re_lu_48 (ReLU) (None, 7, 7, 2048) 0 add_15[0][0] __________________________________________________________________________________________________ global_average_pooling2d_1 (Glo (None, 2048) 0 re_lu_48[0][0] __________________________________________________________________________________________________ dense_8 (Dense) (None, 1000) 2049000 global_average_pooling2d_1[0][0] ================================================================================================== Total params: 25,636,712 Trainable params: 25,583,592 Non-trainable params: 53,120 __________________________________________________________________________________________________

MobileNet (2017)

MobileNet はモデルサイズの軽量化を図りながら、高精度の予測を可能としたモデルです。物体検出などの速度が求められる問題設定で、バックボーンとして使用されています。

Depthwise Separable Convolution

モデルサイズの軽量化に大きく貢献した工夫として、Depthwise Separable Convolution があります。畳み込みの計算を以下の 2 つに分解することで、通常の畳み込み処理から大きくパラメータ数を削減することに成功しました。

  • Depthwise Convolution
  • Pointwise Convolution

37_8

出典:https://arxiv.org/pdf/1704.04861.pdf

Depthwise Convolution

3 チャネルの画像に対し、各チャネルごとに 1 枚ずつフィルタを用意し、チャネル単位で畳み込みを行います。この際、通常の畳み込みとは異なり、各チャネルでの計算結果を足し合わせる処理は行いません。

TensorFlow では layers.DepthWiseConv2D() で実装することができます。

Pointwise Convolution

Depthwise Convolution では入力画像のチャネル数分のチャネルを持った特徴マップが出力されます。

それを、1×11\times1 のフィルタを出力特徴マップのチャネル数分用意して、畳み込み計算を行うのが Pointwise Convolution です。チャネル方向の畳み込みと覚えてください。

TensorFlow では、layers.Conv2D(kernel_size=1, strides=1) で実装することができます。

通常の畳み込み層とのパラメータ数の比較

  • 入力のチャネル数:3
  • 出力のチャネル数:10

上記の条件の際、通常の畳み込み層においてバイアスを除くパラメータは、3×3×3×10=2703 \times 3 \times 3 \times 10 = 270 個となります。

37_9

一方、Depthwise Separable Convolution の場合

  • Depthwise:3×3×3=273 \times 3 \times 3 = 27
  • Pointwise:1×1×3×10=301 \times 1 \times 3 \times 10 = 30
  • 合計:27+30=5727 + 30 = 57

となります。

37_10

def mobilenet(input_shape, n_classes):
  
    def mobilenet_block(x, f, strides=1):
        x = layers.DepthwiseConv2D(3, strides=strides, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        x = layers.Conv2D(f, 1, strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        return x

    input = layers.Input(input_shape)

    x = layers.Conv2D(32, 3, strides=2, padding='same')(input)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = mobilenet_block(x, 64)
    x = mobilenet_block(x, 128, 2)
    x = mobilenet_block(x, 128)

    x = mobilenet_block(x, 256, 2)
    x = mobilenet_block(x, 256)

    x = mobilenet_block(x, 512, 2)
    for _ in range(5):
        x = mobilenet_block(x, 512)

    x = mobilenet_block(x, 1024, 2)
    x = mobilenet_block(x, 1024)

    x = layers.GlobalAveragePooling2D()(x)

    output = layers.Dense(n_classes, activation='softmax')(x)

    model = models.Model(input, output)

    return model
      
model = mobilenet((224, 224, 3), 1000)
model.summary()
      
Model: "model_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 224, 224, 3)] 0 _________________________________________________________________ conv2d_123 (Conv2D) (None, 112, 112, 32) 896 _________________________________________________________________ batch_normalization_53 (Batc (None, 112, 112, 32) 128 _________________________________________________________________ re_lu_49 (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ depthwise_conv2d (DepthwiseC (None, 112, 112, 32) 320 _________________________________________________________________ batch_normalization_54 (Batc (None, 112, 112, 32) 128 _________________________________________________________________ re_lu_50 (ReLU) (None, 112, 112, 32) 0 _________________________________________________________________ conv2d_124 (Conv2D) (None, 112, 112, 64) 2112 _________________________________________________________________ batch_normalization_55 (Batc (None, 112, 112, 64) 256 _________________________________________________________________ re_lu_51 (ReLU) (None, 112, 112, 64) 0 _________________________________________________________________ depthwise_conv2d_1 (Depthwis (None, 56, 56, 64) 640 _________________________________________________________________ batch_normalization_56 (Batc (None, 56, 56, 64) 256 _________________________________________________________________ re_lu_52 (ReLU) (None, 56, 56, 64) 0 _________________________________________________________________ conv2d_125 (Conv2D) (None, 56, 56, 128) 8320 _________________________________________________________________ batch_normalization_57 (Batc (None, 56, 56, 128) 512 _________________________________________________________________ re_lu_53 (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ depthwise_conv2d_2 (Depthwis (None, 56, 56, 128) 1280 _________________________________________________________________ batch_normalization_58 (Batc (None, 56, 56, 128) 512 _________________________________________________________________ re_lu_54 (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ conv2d_126 (Conv2D) (None, 56, 56, 128) 16512 _________________________________________________________________ batch_normalization_59 (Batc (None, 56, 56, 128) 512 _________________________________________________________________ re_lu_55 (ReLU) (None, 56, 56, 128) 0 _________________________________________________________________ depthwise_conv2d_3 (Depthwis (None, 28, 28, 128) 1280 _________________________________________________________________ batch_normalization_60 (Batc (None, 28, 28, 128) 512 _________________________________________________________________ re_lu_56 (ReLU) (None, 28, 28, 128) 0 _________________________________________________________________ conv2d_127 (Conv2D) (None, 28, 28, 256) 33024 _________________________________________________________________ batch_normalization_61 (Batc (None, 28, 28, 256) 1024 _________________________________________________________________ re_lu_57 (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ depthwise_conv2d_4 (Depthwis (None, 28, 28, 256) 2560 _________________________________________________________________ batch_normalization_62 (Batc (None, 28, 28, 256) 1024 _________________________________________________________________ re_lu_58 (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ conv2d_128 (Conv2D) (None, 28, 28, 256) 65792 _________________________________________________________________ batch_normalization_63 (Batc (None, 28, 28, 256) 1024 _________________________________________________________________ re_lu_59 (ReLU) (None, 28, 28, 256) 0 _________________________________________________________________ depthwise_conv2d_5 (Depthwis (None, 14, 14, 256) 2560 _________________________________________________________________ batch_normalization_64 (Batc (None, 14, 14, 256) 1024 _________________________________________________________________ re_lu_60 (ReLU) (None, 14, 14, 256) 0 _________________________________________________________________ conv2d_129 (Conv2D) (None, 14, 14, 512) 131584 _________________________________________________________________ batch_normalization_65 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_61 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_6 (Depthwis (None, 14, 14, 512) 5120 _________________________________________________________________ batch_normalization_66 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_62 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_130 (Conv2D) (None, 14, 14, 512) 262656 _________________________________________________________________ batch_normalization_67 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_63 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_7 (Depthwis (None, 14, 14, 512) 5120 _________________________________________________________________ batch_normalization_68 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_64 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_131 (Conv2D) (None, 14, 14, 512) 262656 _________________________________________________________________ batch_normalization_69 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_65 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_8 (Depthwis (None, 14, 14, 512) 5120 _________________________________________________________________ batch_normalization_70 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_66 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_132 (Conv2D) (None, 14, 14, 512) 262656 _________________________________________________________________ batch_normalization_71 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_67 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_9 (Depthwis (None, 14, 14, 512) 5120 _________________________________________________________________ batch_normalization_72 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_68 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_133 (Conv2D) (None, 14, 14, 512) 262656 _________________________________________________________________ batch_normalization_73 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_69 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_10 (Depthwi (None, 14, 14, 512) 5120 _________________________________________________________________ batch_normalization_74 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_70 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ conv2d_134 (Conv2D) (None, 14, 14, 512) 262656 _________________________________________________________________ batch_normalization_75 (Batc (None, 14, 14, 512) 2048 _________________________________________________________________ re_lu_71 (ReLU) (None, 14, 14, 512) 0 _________________________________________________________________ depthwise_conv2d_11 (Depthwi (None, 7, 7, 512) 5120 _________________________________________________________________ batch_normalization_76 (Batc (None, 7, 7, 512) 2048 _________________________________________________________________ re_lu_72 (ReLU) (None, 7, 7, 512) 0 _________________________________________________________________ conv2d_135 (Conv2D) (None, 7, 7, 1024) 525312 _________________________________________________________________ batch_normalization_77 (Batc (None, 7, 7, 1024) 4096 _________________________________________________________________ re_lu_73 (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ depthwise_conv2d_12 (Depthwi (None, 7, 7, 1024) 10240 _________________________________________________________________ batch_normalization_78 (Batc (None, 7, 7, 1024) 4096 _________________________________________________________________ re_lu_74 (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ conv2d_136 (Conv2D) (None, 7, 7, 1024) 1049600 _________________________________________________________________ batch_normalization_79 (Batc (None, 7, 7, 1024) 4096 _________________________________________________________________ re_lu_75 (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ global_average_pooling2d_2 ( (None, 1024) 0 _________________________________________________________________ dense_9 (Dense) (None, 1000) 1025000 ================================================================= Total params: 4,264,808 Trainable params: 4,242,920 Non-trainable params: 21,888 _________________________________________________________________

この流れのように、CNN はパラメータ数を削減しながら精度を向上させるように発展が進められてきました。現在でも活発に研究が進められているため、今後もどのようなモデルが発表されるのか楽しみです。

本章で紹介したモデルは、有名なモデルの一部ではありますが、最も重要なアーキテクチャがたくさん入っているので紹介しました。モデルの内部をすべて覚える必要はありませんが、各モデルがどんな特徴を持っているかはなんとなく覚えておくと、これからコンピュータビジョンの分野を進んでいく上で強い武器になります。これからも学び続けていきましょう。

shareアイコン