Vision transformer (ViT)を用いた画像認識のコード解説。初心者向 けにtensorflow keras APIのコードをわかりやすく解説します。

データサイエンス

 数年前までは、画像認識のアルゴリズムといえば、CNN(Convolutional Neural Network)が周流でありましたが、最近Vision Transformerを用いた事例が増えつつあり、精度が高い結果も見受けられます。

今回は初心者の方向けに、Vision Transformerのコードわかりやすく日本語で解説します。

Vision Transformerとは

 Vision Transformerとは、これまで主に自然言語処理のAIアルゴリズムの一部として、用いられてきた、Transoformerという手法を、画像認識に適用したアルゴリズムです。

数年前までは、画像認識といえば畳み込み(convolution)を用いたCNNが主流でしたが、最近では、Vision Transformerが有力視されています。

サンプルコード

 今回解説するのは、tensorflow kerasのオフィシャルページに掲載されている、サンプルコードです。比較的シンプルで、理解しやすいので初心者の方はぜひ参考になさってください。

Keras documentation: Image classification with Vision Transformer
Keras documentation

Introduction 準備

ツールをインストールします。tensorflowとtensorflowのアドオンです。

pip install tensrflow
pip install -U tensorflow-addons

Setup インポート

ツール関係のインポートです。numpy、tensorflow、keras、tensorflow、tensorflow_addonsの5種類です。

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

Prepare the data データの準備

cifar100のデータセットを使います。100クラスの画像分類のデータセットで、画像サイズは、縦32pixel、横32pixel、3chのデータです。下記コードを実行するとデータセットのダウンロードが始まりますので、ダウンロード完了までしばらくお待ちください。

num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

Configure the hyperparameters ハイパーパラメータの設定

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extract from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier

ハイパーパラメータの設定です。精度を高めるためには、調整が必要になりますが、まずは、なんとなく理解しておきましょう。

  • learning_rate:学習率
  • weight_decay:学習率の調整パラメータ
  • batch_size:バッチサイズ(一度に学習させる画像のデータサイズ)
  • num_epoch:エポック数
  • image_size:72pixel(データセットの画像は32pixelですがリサイズして使います)
  • pathc_size:6pixel(画像を細かく分けるパッチのサイズです)
  • num_patches:画像一枚あたりのパッチ数(画像は縦横72pixelなので、6×6pixelのパッチが、12×12=144個できます)
  • projection_dim:4
  • transformer_units:[8, 4]トランスフォーマーレイヤーのサイズです
  • transformer_layer:トランスフォーマの層の数
  • mlp_head_units:[2048, 1024] (mlp(multi layer perceptron) 全結合層のサイズ)

Use data augmentation データ拡張、水増しを使うところ

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)

画像のデータ拡張です。

  • layers.Normalization():バッチごとに正規化を行います
  • layers.Resizing():画像のリサイズ
  • layers.RandomFlip(“horizontal”):画像を水平方向に反転させます。反転させるのは、ランダムに発生
  • layers.RondomRotation(factor=0.02):画像をランダムに回転させます。回転角はプラスマイナス0.02ラジアン内でランダムに。
  • layres.RondomZoom:画像をランダムにズーム、高さ方向20%、幅方向20%拡大

???最終行で、学習データx_trainにデータ拡張を設定しています。

Implement multilayer perceptron (MLP) 全結合層を埋め込む関数

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

MLP(multilayer perceptron)、全結合層の関数を作っておきます。

hidden_unitsは隠れ層の数、dropout_rateはドロップアウト率です。

ここは、単純な全結合層です。

初心者の方は、掛け算と足し算をなんども行っているところとくらいに覚えておいてください。

Implement patch creation as a layer パッチクリエーション層を埋め込むクラス

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

Let’s display patches for a sample image 一枚のサンプル画像でパッチを表示してみよう!

パッチに切り刻んだ様子を表示してみます。

import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")

3行目で、ランダムに学習画像を1枚選びます。

9行目で、先ほど作成したPathcesクラスをつかって、patchesにパッチ画像作成します。

16~20行目でmatplotを使って、パッチ画像を表示します。下記は表示例です。乱数を使って一枚選ぶので、実行するたびに変わります。

Implement the patch encoding layer パッチエンコーディング層の埋め込み

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

パッチエンコーディング層といViTの構造の一部を作成します。

全結合層と、embedding層で作成されています。

初めは、モデルの構造を定義しているんだな~くらいの理解でOKです。

Build the ViT model ViTモデルを構築します

def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model

これまでに作成した、関数や、クラスを活用しながら、ViTモデルを完成させます。下記に概要を記載しておりますが、細部の理解は、ある程度慣れてきてからにしましょう。

  • 2行目 input 入力層
  • 4行目 augmented:データ拡張
  • 6行目 patches:画像データを切り刻んで、パッチデータを作成します
  • 8行目 encoded_patches:PatchEncoderクラスを使ってパッチエンコーディング
  • 11行目 x1:正規化
  • 13~15行目 attention層
  • 18行目 skip conection
  • 21行目 正規化
  • 23行目 MLP(Multi layer perceptron)関数を使って 
  • 25行目 skip conection 2回目
  • 27~29行目 出力の形を整える
  • 31行目 MLPを追加する

Compile, train, and evaluate the model モデルとコンパイルと、学習と、評価

def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)

学習の準備と、実行です。

  • 2~4行目 optimizer:最適化の設定を行います。weight decay付きのAdamを使っています。
  • 5~12行目 モデルのコンパイルです。コンパイルは、学習の諸条件の設定と思ってください。optimizer、loss(損失関数)、metrics(精度)の計算方法を設定しています。精度は、クラスの予測精度と、top5に入ったどうかの予測精度2つを設定しています。後者のイメージは、画像一枚につき、100クラスのどれに入りそうか選んだものの中で、上位5クラスに入るかどうかです。
  • 13~19行目 モデルのチェックポイントの設定です。学習中にモデルを保存する条件を設定しています。valデータでベストを更新したときに保存するような設定にしていますね。
  • 20~26行目 モデルの学習を実行しています。学習データの画像、学習データのラベル、バッチサイズ、エポック数、検証データに回すデータの比率を引数で設定しています。
  • 27~30行目 学習したモデルを読み込んで、テストデータで精度を確認します
  • 最後の2行 自作した関数を用いてViTモデルを構築して、自作した関数を用いて、学習検証を行っています。

100epochの学習で、精度55%、top5の予測精度85%達成します。

今回のサンプルコードは、シンプルなものです、より高い精度を狙うためには、

  • epoch数を増やす
  • トランスフォーマ層を増やす
  • 入力画像をリサイズする(大きくした方がいいですかね)
  • パッチサイズを変える(同じことですが、projection_dimentionを変える)
  • 学習率、optimizer、weight_decayなども影響するようです。

より確実に精度を上げるためには、大規模な高解像度データでプリトレーニング済みのモデルのファインチューニングを行うことを進めています。

まとめ

 画像認識のアルゴリズムとして、定着しつつあるVision Transformerのサンプルコードを、初心者向けにわかりやすく解説しました。

 深層学習のアルゴリズムは、年々新しいものが発表されてフォローするのが大変ですが、なるべく新しくて精度のよいものを使いこなせるようにしましょう。

タイトルとURLをコピーしました