データ拡張

ディープラーニングの世界では大量のデータ拡張が必要であることが前提となるため、欲しいデータが潤沢にあれば良いのですが、実現場ではなかなか求めているデータ数を集めることは思ったようにいかないケースが多いです。

そういった場合に、学習データ拡張の画像に対して移動、回転、拡大・縮小など人工的な操作を加えることでデータ拡張数を水増しするテクニックがあります。水増しされることで同じ画像が学習されることが少なくなるので汎化性能が向上されることが期待されます。

本章では、水増しテクニックである データ拡張 (Data Augmentation) の代表的な処理を確認したうえで、適用前後で精度がどのように変化するかを確認します。

PyTorch では、torchvision.transforms に様々な水増しのメソッドが用意されているため、簡単に実装が可能です。

代表的な処理として、以下があげられますのでそちらを順番に実装していきます。

  • 回転
  • 水平移動
  • せん断
  • 拡大
  • 水平反転
  • 垂直反転

本章の流れ

  • ベースモデルの作成
  • 各処理の確認
  • 各処理適用後の画像を保存
  • データによる精度の確認

ベースモデルの作成

前章と同じように、まずはベースモデルを作成しましょう。

!pip install pytorch_lightning
      
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
      
pl.__version__
      
'0.7.1'

データセットの準備

本章でも CIFAR10 と呼ばれる 10 クラス分類を行います。torchvision にデータセットが用意されています。

# 前処理
transform = transforms.Compose([
    transforms.ToTensor(),
])
      
# データの取得と分割
train_val = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

# train : val = 0.8 : 0.2
n_train = int(len(train_val) * 0.8)
n_val = len(train_val) - n_train

# ランダムに分割を行うため、シードを固定して再現性を確保
torch.manual_seed(0)

# train と val に分割
train, val = torch.utils.data.random_split(train_val, [n_train, n_val])
      

それでは、今回扱うデータを 25 枚ランダムに抜粋して表示します。

正解ラベル 種別
0 airplane
1 automobile
2 bird
3 cat
4 deer
5 dog
6 frog
7 horse
8 ship
9 truck

10 クラス分類となっており、上記の表の種別を分類することが目標です。

#画像の表示
plt.figure(figsize=(12,12))
for i in range(25):
    img = np.transpose(train[i][0].numpy(), (1, 2, 0))
    plt.subplot(5, 5, i+1)
    plt.imshow(img)
      
<Figure size 864x864 with 25 Axes>

モデルの定義と学習

今回は、データ拡張の効果検証で学習データと検証データの正解率を比較します。そのため、TrainNet クラスに正解率を算出するスクリプトを追記しています。

class TrainNet(pl.LightningModule):

    @pl.data_loader
    def train_dataloader(self):
        return torch.utils.data.DataLoader(train, self.batch_size, shuffle=True)

    def training_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        # 追加
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'loss': loss, 'acc': acc}
        return results
      
class ValidationNet(pl.LightningModule):

    @pl.data_loader
    def val_dataloader(self):
        return torch.utils.data.DataLoader(val, self.batch_size)

    def validation_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'val_loss': loss, 'val_acc': acc}
        return results

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        results =  {'val_loss': avg_loss, 'val_acc': avg_acc}
        return results
      
class TestNet(pl.LightningModule):

    @pl.data_loader
    def test_dataloader(self):
        return torch.utils.data.DataLoader(test, self.batch_size)

    def test_step(self, batch, batch_nb):
        x, t = batch
        y = self.forward(x)
        loss = self.lossfun(y, t)
        y_label = torch.argmax(y, dim=1)
        acc = torch.sum(t == y_label) * 1.0 / len(t)
        results = {'test_loss': loss, 'test_acc': acc}
        return results

    def test_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        results = {'test_loss': avg_loss, 'test_acc': avg_acc}
        return results
      
class Net(TrainNet, ValidationNet, TestNet):

    def __init__(self, batch_size=128):
        super(Net, self).__init__()
        self.batch_size = batch_size
        # 畳み込み層
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        # 全結合層
        self.fc1 = nn.Linear(128*4*4, 128)
        self.fc2 = nn.Linear(128, 10)

    def lossfun(self, y, t):
        return F.cross_entropy(y, t)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        # ch: 3 -> 32, size: 32 * 32 -> 16 * 16
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)

        # ch: 32 -> 64, size: 16 * 16 -> 8 * 8
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)

        # ch: 64 -> 128, size: 8 * 8 -> 4 * 4
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)

        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
      
# 乱数のシードを固定
torch.manual_seed(0)

# ネットワーク学習の準備
net = Net(batch_size=1024)
trainer = Trainer(gpus=1, max_epochs=50, early_stop_callback=False)
      
trainer.fit(net)
      
trainer.test()
trainer.callback_metrics
      
HBox(children=(IntProgress(value=0, description='Testing', layout=Layout(flex='2'), max=10, style=ProgressStyl…
---------------------------------------------------------------------------------------------------- TEST RESULTS {} ----------------------------------------------------------------------------------------------------
{'loss': 0.5888516306877136, 'acc': 0.796875, 'val_loss': 0.9243364334106445, 'val_acc': 0.7051758170127869, 'epoch': 49, 'test_loss': 0.9193058013916016, 'test_acc': 0.7019471526145935}
Train Val Test
Base Accuracy 0.797 0.705 0.702
Base Loss 0.589 0.924 0.919

上記のスコアをベースラインとして、データを適用することで汎化性能が向上するか確認しましょう。

各処理の確認

具体的に適用する前に、代表的な水増し処理を確認します。まずは、CIFAR10 の画像を PIL の型に変換し、サンプルイメージとして用意しましょう。

transforms による変換は、Pillow のデータ形式の入力を前提としています。

# 画像の読み込みと PIL に変換
img = train[100][0]
ToPIL = transforms.ToPILImage()
img = ToPIL(img)
      

また、各処理の変換前と変換後を比較し、表示する関数を事前に作成しておきましょう。何度も使う処理は先に関数化しておくと、使い回せるため便利です。

# 各処理の変換前後を表示
def show(in_img, out_img):
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.title('before')
    plt.imshow(in_img)
    
    plt.subplot(1, 2, 2)
    plt.title('after')
    plt.imshow(out_img)
      

回転

transforms.RandomRotation で変換できます。degrees 引数で回転の度合いを指定します。

transform = transforms.RandomRotation(degrees=30)
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

水平移動

transforms.RandomAffine() で変換できます。translate 引数で縦横方向への移動の幅の指定ができます。

例えば 32 ピクセルの正方形の画像で、translate=(0.5, 0) とした場合、縦方向に -16 ~ 16 の幅でランダムに水平移動します。

transform = transforms.RandomAffine(degrees=0, translate=(0.5, 0.5))
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

せん断

せん断は、四角形の画像を平行四辺形に変形する処理です。transforms.RandomAffine() で変換できます。shear 引数で縦、横方向のせん断の度合いを指定できます。

transform = transforms.RandomAffine(degrees=0, translate=(0, 0), shear=(0, 30))
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

拡大

transforms.RandomCrop で変換します。引数にはクロップする縦横の画素数を指定します。

transform = transforms.RandomCrop((16, 16))
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

水平反転

transforms.RandomHorizontalFlip で変換します。引数 p には、反転を起こす確率を与えます。もしも 1 とすると、100% 水平に反転するということを表します。

transform = transforms.RandomHorizontalFlip(p=1) 
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

垂直反転

transforms.RandomVerticalFlip で変換します。引数 p には、反転を起こす確率を与えます。もしも 1 とすると、100% 垂直に反転するということを表します。

transform = transforms.RandomVerticalFlip(p=1) 
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

option : fillcolor

回転や水平移動、せん断で生じる空白箇所は、デフォルトでは輝度が 0(黒色)で埋められます。

空白箇所を任意の色で塗りつぶしたい場合は、fillcolor 引数で指定します。カラー画像の場合、RGB の輝度をタプルで与えます。

#水平移動し、空白はグレーで埋める
transform = transforms.RandomAffine(degrees=0, translate=(0.5, 0.5), fillcolor=(100, 100, 100))
out = transform(img)

show(img, out)
      
<Figure size 720x720 with 2 Axes>

各処理適用後の画像を保存

処理適用をすると、適用後の画像は保存した上で改めて学習をおこなうことが多々あります。その場合に、処理適用後の画像を保存するには img.save('保存したい名前') とします。今回は out.jpg としましょう。

# RandomCrop 処理をかける
transform = transforms.RandomCrop((16, 16))
out = transform(img)

# 画像の保存
out.save('out.jpg')
      

データ拡張による精度の変化

データ拡張は、適用する処理を明示的に与えるわけではなく、複数の処理を transforms.Compose() 内に宣言することでランダムに処理が選ばれて適用されます。例えば、拡大と水平変換とせん断を入れたとしても、毎回すべてが適用されるわけではなく、上手く選ばれながら画像を増やしていきます。

今回は水平変換と垂直変換を入れて、汎化性能が向上するか確認しましょう。

# 前処理
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(), 
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
      
# データの分割
train_val = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform)

# train : val = 0.8 : 0.2
n_train = int(len(train_val) * 0.8)
n_val = len(train_val) - n_train

# ランダムに分割を行うため、シードを固定して再現性を確保
torch.manual_seed(0)

# train と val に分割
train, val = torch.utils.data.random_split(train_val, [n_train, n_val])
      

モデルの学習

# 乱数のシードを固定
torch.manual_seed(0)

# ネットワーク学習の準備
net = Net(batch_size=1024)
trainer = Trainer(gpus=1, max_epochs=50, early_stop_callback=False)
      
trainer.fit(net)
      
trainer.test()
trainer.callback_metrics
      
HBox(children=(IntProgress(value=0, description='Testing', layout=Layout(flex='2'), max=10, style=ProgressStyl…
---------------------------------------------------------------------------------------------------- TEST RESULTS {} ----------------------------------------------------------------------------------------------------
{'loss': 0.6235649585723877, 'acc': 0.75, 'val_loss': 0.815513551235199, 'val_acc': 0.7154735922813416, 'epoch': 49, 'test_loss': 0.8038299679756165, 'test_acc': 0.7226323485374451}
Train Val Test
Base Accuracy 0.797 0.705 0.702
Base Loss 0.589 0.924 0.919
Augmentation Accuracy 0.75 0.715 0.723
Augmentation Loss 0.624 0.816 0.804

検証、テストデータの正解率が向上し、学習データの正解率との乖離が小さくなりました。汎化性能が向上したことが確認できました。

データ拡張は、簡単な処理でありながら手軽に精度向上に貢献してくれる重要な手法です。最近では画像認識分野でスタンダードな前手法になっていますので、ぜひ適用してください。

shareアイコン