Facebookが公開している、VisionTransformeモデル、DeiTのサンプルコードを、初心者の方向けに解説します。細部の理論はとりあえずおいて、動作することを目標に解説しますので、ご活用ください。
DeiTとは
DeiT(Data-efficient Image Transformers)とは、Facebookが公開している、Vision Transformerという種類分される、画像分類のアルゴリズムです。
画像分類のアルゴリズムといえば、長年CNN(Convolutional Neural Network)が主流でしたが、精度面で、最近では、Vision Transformerに置き換わりつつあります。
本サンプルプログラムは、モバイル端末への実装も考慮した軽量化について説明されておりますので、特にモバイルへ実装される方はご活用ください。
オフィシャルページのサンプルプログラムこちらになります。
ちなみに、DeiTは、データ拡張に工夫をすることで、CNNにくらべて、少ないデータ数で学習可能で、CNNを教師とした蒸留行っているそうです。慣れてきたら詳しく学んでください。
サンプルコード解説
準備(インストール)
ツール関係をインストールします。pipでtorch、torchvision、tim、pandas、requestsを入れます
pip install torch torchvision timm pandas requests
google colaboで試す方はこちらです。先頭に!です。
!pip install timm pandas requests
学習済みのDieTで画像分類実行(PCで)
下記は、学習済みのDieTモデルで画像分類をPCで実行するコードです。
シンプルで、簡単ですね。
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
print(torch.__version__)
# should be 1.8.0
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
transform = transforms.Compose([
transforms.Resize(256, interpolation=3),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])
img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
- 1行目、画像処理ライブラリ Pillowをインポート。
- 2行目Pytorchライブラリtorchをインポート。
- 3行目、Pytorchの画像処理関係のライブラリ、torchvisionをインポート。
- 4行目HTTP通信用のライブラリrequestsをインポート。
- 5行目、6行目、その他細かいライブラリをインポート
- 12行目フェイスブックが用意してくれている、学習済みモデルをロードします。
- 13行目推論用のモデルに切り替えます
- 15~20行目 前処理に設定です。
- 16行目はリサイズの設定です。引数interpolationは、リサイズ時の補正アルゴル選択です。3が精度のいいやつです。
- 17行目は、クロップ(切り取り)設定です。中心224×224を切り取ります。
- 18行目は、画像データを、Pytorchのtensor形式に変換します
- 19行目は、データの正規化です。データベース内の画像の平均値と、標準偏差を使って正規化します。
- 22行目で、Pilowで画像を一枚読み込みます。縦224pix、横224pix、3chの画像データです。
- 23行目で、読み込んだ1枚の画像に対して前処理を実行します。
- 24行目で、modelを使って画像認識を行います。1000クラスのそれぞれの確率がでます。(イメージネットなので、1000クラス分類ですね)
- 25行目で、確率最大の番号を調べます。269番目です。「timber wolf, grey wolf, gray wolf, Canis lupus’」みたいです。。。
スクリプティング DeiT(モバイルで実行できる形へ変換)
モバイルで使用する場合は、TorchScript 形式に変換します。拡張子 .ptのファイルに変換です。
さすが、フェイスブック簡単にモバイルで実行できる形に変換できますね!
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")
- 1行目、フェイスブックが公開している学習済みモデルをロードします。
- 2行目、モデルを推論モードに切り替えます
- 3行目、モデルのスクリプトを作成します
- 4行目、モデルのスクリプトを保存します。
量子化でモデルを軽くする。
モバイルで使うために、モデルを量子化して、軽くします。
# Use 'x86' for server inference (the old 'fbgemm' is still available but 'x86' is the recommended default) and 'qnnpack' for mobile inference.
backend = "x86" # replaced with qnnpack causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("fbdeit_scripted_quantized.pt")
ここは、変換方法なので、簡単に説明します。深く考えずに真似すればいいと思います。
- 2~4行目、量子化の設定をします
- 6行目、量子化の実行。
- 7行目、スクリプト作成
- 8行目スクリプト保存。
量子化により、なんと89MBまで小さくなります。(もともと346MBなので、74%削減)
ここまで、軽い画像認識モデルはないですね。。。
量子化モデルを実行して確認
量子化したスクリプトの確認です。量子化しても結果は同じになります。
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
DeiT の最適化
モバイル向けのスクリプト作成の最後のステップとして、最適化を行います。
詳細が記載されていないので不明ですが、なにかしらの最適化を行うようです。
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")
- 1行目で、最適化のライブラリをインポートします
- 2行目、量子化したモデルを最適化します
- 3行目、量子化して、最適化したモデルのスクリプトを保存します。
念のため確認
念のため、量子化、最適化後のモデルの推論結果を確認します。推論結果は変わらないです。
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
Lite インタープリターの使用
ライトインタープリターというものを使うと、さらにモデルをちいさく、推論を高速にできるようです。
optimized_scripted_quantized_model._save_for_lite_interpreter("fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("fbdeit_optimized_scripted_quantized_lite.ptl")
- 1行目、ライトインタープリターでスクリプトを保存します
- 2行目、保存した、ライトインタープリタースクリプトで実行するモデルを読み込みます。
推論速度の比較(ここまで登場した4モデルで)
ここまで登場した4モデルの推論速度を比較します。すべてgpuを使用しません。
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
out = ptl(img)
- 1~2行目、(1)軽量化を行っていないモデルを読み込みます
- 3~4行目、(2)モバイル向けにスクリプト化されたモデルを読み込みます
- 5~6行目、(3)量子化、スクリプト化されたモデルを読み込みます
- 7~8行目、(4)量子化、スクリプト化、最適化されたモデルを読み込みます
- 9~10行目、(5)量子化、スクリプト化、最適化、Liteインタープリターを使用したモデルを読み込みます。
気になる推論速度の実行例は、
- (1)1236.69ms
- (2)1226.72ms
- (3)593.19ms
- (4)598.01ms
- (5)600.72ms
1s(gpuなしで、1000msを切るとは速いですね)
推論速度の検証結果をデータフレームにまとめる
下記では、推論速度の検証結果をpandasデータフレーム形式にまとめています。
import pandas as pd
import numpy as np
df = pd.DataFrame({'Model': ['original model','scripted model', 'scripted & quantized model', 'scripted & quantized & optimized model', 'lite model']})
df = pd.concat([df, pd.DataFrame([
["{:.2f}ms".format(prof1.self_cpu_time_total/1000), "0%"],
["{:.2f}ms".format(prof2.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof2.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof3.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof3.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof4.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof4.self_cpu_time_total)/prof1.self_cpu_time_total*100)],
["{:.2f}ms".format(prof5.self_cpu_time_total/1000),
"{:.2f}%".format((prof1.self_cpu_time_total-prof5.self_cpu_time_total)/prof1.self_cpu_time_total*100)]],
columns=['Inference Time', 'Reduction'])], axis=1)
print(df)
- 1~2行目、pandasとnumpyをインポート
- 4行目、データフレームを作成、とりあえずヘッダー
- 5行目以降、それぞれのモデルでの推論時間と、オリジナルのモデルと比較した推論時間の削減率をひょじします。
まとめ
Facebookが作成している、PytorchのVisionTransformer DeiTのサンプルコードについて解説しました。このモデルは、従来のCNNなどに比べて非常に軽量で精度も高いです。
しかも、比較的容易に実装できますので、ぜひ使いこなせるようになりたいですね。