Pytorch CNN 画像分類・画像認識 サンプルコード解説 初心者向けです!

データサイエンス

 今回も、画像認識のサンプルコードを解説します。facebookが作成している、Pytorchで、アルゴリズムは王道のCNN(convolutional neural network)です。

細かいところは、省いて説明していきますので、動作することを目指しましょう!

はじめに

今回紹介するのは、サンプルコードはこちらのものです。pytorchは、tensorflow kerasに比べてコードが長くなりますが、どんな計算を行っているか具体的にイメージしやすくなりますので、ぜひ最後までお付き合いください。

Transfer Learning for Computer Vision Tutorial — PyTorch Tutorials 2.2.2+cu121 documentation

サンプルコード解説

ライブラリのインポート

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

cudnn.benchmark = True
plt.ion()   # interactive mode

各種ライブラリのインポートです

  • 6行目:pytorch
  • 7行目:pytorchのニューラルネットワーク関連
  • 8行目:最適化関連
  • 9行目:最適化の学習率のスケジューラ
  • 10行目:cudnnを使うためのもの。gpuを使うときの設定ですね
  • 11行目:numpy
  • 12行目:pytorchのコンピュータビジョン関連
  • 13行目:画像のデータセット、モデル、画像変換・画像処理
  • 14行目:マットプロット
  • 15行目:時間計測
  • 16行目:ファイル操作
  • 17行目:コピー
  • 19行目:cudnn使います
  • 20行目:マットプロットをインタラクティブモードにします。

画像データを読み込んだり前処理を加えるところ

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  • 3~16行目データの前処理です 4~9行目が学習データ、10~15行目が検証データに行う処理です。
  • 5行目:画像からランダムな位置で切り取って、224×224pixにリサイズします。
  • 6行目:ランダムに水平方向に画像を反転させます。
  • 7行目:画像データをテンソルに変換します。
  • 8行目:標準化方法に合わせて、数値を変換します。初めの3つが、RGBそれぞれの平均値、次の3つがRGBの標準偏差です。
  • 11行目:256×256pixにリサイズします。
  • 12行目:画像中心224×224pixを切り取ります
  • 13行目:画像データをテンソルに変換します。
  • 14行目:標準化します。
  • 18行目:画像データのフォルダ
  • 19~20行目:データセットを辞書形式で作成しておきます。上で作成したtransfromsを使います
  • 22行目~24行目:データローダーの設定です。学習データと、検証データそれぞれで設定します。batch_sizeは一度に計算する画像データ数。shuffleは、データを取り出す順番をシャッフルするか。num_workers、計算機の数。
  • 25行目:データセットのサイズを把握しておきます
  • 26行目:各データのラベル名
  • 28行目:デバイス設定。gpuが使えるならgpuを使います。つかえないなら、仕方がないのでcpu。

試しに画像をいくつか見てみよう

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
  • 1~11行目:画像表示する自作関数です。
  • 3行目:numpyの配列を変換します。RGBの順番にします。
  • 4行目:標準化した時の平均値
  • 5行目:標準化した時の標準偏差
  • 6行目:標準化します。
  • 7行目:0~1に制限します
  • 8行目:マットプロットで描画します
  • 9~10行目:タイトルを表示します
  • 11行目:0.001s表示を止めます。
  • 15~20行目:実施に自作関数をつかって表示
  • 15行目:画像を一枚づつ取り出します
  • 18行目:グリッド上に画像を並べて表示します
  • 20行目:自作関数で画像データを表示します

学習(の関数)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
  • 1~57行目:学習するための関数。長いです。。。頑張って理解しましょう!
  • 4行目:一番よかったモデルの重みをコピーしておきます。
  • 5行目:最高精度。ゼロにしておきます。
  • 7行目から、1エポックずつ計算
  • 12行目:学習と検証の切り替え。学習、検証の順番です
  • 13~16行目:モデル設定を学習モードか、推論モードかにする
  • 18行目:lossの初期値
  • 19行目:corretctsの初期値
  • 22行目から、データを取り出しつつ学習または検証していきます
  • 23~24行目:デバイス(cpuまたは)
  • 27行目:勾配をゼロにしておきます
  • 29行目からは、順方向の計算です
  • 31行目:学習ならば、パラメータを調整できるようにしておきます
  • 32行目:モデルに画像を入れて出力計算
  • 33行目:誤差を計算します。criterionで出力と正解ラベルの比較をおこないます
  • 36行目以降で逆方向の計算
  • 37行目:学習中モードならば。
  • 38行目:誤差逆伝搬させて、
  • 39行目:重みを調整します。
  • 42行目:ランニングロスを集計
  • 43行目:ランニングコレクトを集計
  • 44~45行目:学習モードならば、スケジューラのステップを進めます
  • 47行目:今のエポックでのロスを計算
  • 48行目:今のエポックでのcorrects計算
  • 53行目~55行目で、ベスト解が出たらモデルをほぞんしておきます
  • 54行目:ベストaccを更新
  • 55行目:ベスト解を出したモデルの重みを更新
  • 57行目:見やすいように開業を入れて。
  • 59行目:時間を計っておきます
  • 60行目:学習が終わったことを出力
  • 61行目:最高精度を出力
  • 64行目:モデルにベスト解の時の重みをロードしておいて。(そのままだと、最後に計算したときの重み)
  • 65行目:モデルを返す

長いです。。。が、CNNの学習がどんなことをしているかよくわかりますね!

推論 画像を数枚で試してみる(自作関数)

def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title(f'predicted: {class_names[preds[j]]}')
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)
  • 2行目:学習済みのモデル
  • 3行目:モデルを推論モードにする
  • 4行目:画像の通し番号
  • 5行目:フィギュアのオブジェクトをつくっておいて
  • 7行目:勾配は変えずに
  • 8行目:データローダーから順番に画像をとってくる(バッチ数分の画像を一度に)
  • 9行目:デバイスに画像をのっける
  • 10行目:デバイスにラベルを乗っける
  • 12行目:モデルに画像を入れて、推論実行
  • 13行目:確率が一番大きいのを計算して、推論ラベルを計算
  • 15行目画像を書いていきます。
  • 16行目:画像の通し番号を一つ増やして
  • 17行目:サブプロットの指定です。画像数//2 行、2列、で画像を表示します。
  • 18行目:軸は邪魔なのでけしておきます
  • 19行目:タイトルは、予測したクラス名
  • 20行目:忘れちゃいけない imshowです。
  • 22行目:最後の画像まで終わったら。
  • 23行目:モデルを学習モードに戻しておきます
  • 24行目:関数抜けます。
  • 25行目:データがなかったときも。。。モデルを学習モードに戻しておきます。手が込んでますね。。。

学習と推論の自作関数を 使います。resnet18のファインチューニングです。

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

ここからは、上で作った関数を使うのでらくですよ

  • 1行目:resnet18を読み込みます。イメージネットで学習済みのモデルです。
  • 2行目:resnet18の特徴量の数です
  • 3行目以降は、resnet18の出力を2クラス分類になるように改造します
  • 5行目:全結合は、特徴量の数から、2つの出力にまとめるものにしますよ
  • 9行目:精度は、クロスエントロピーです。2クラスなのでバイナリーですね。
  • 12行目:オプティマイザーの設定です。シンプルなSGDで学習率0.001、モメンタム0.9です。
  • 15行目:学習率のスケジューラ設定です。step_sizeが7、gammaが0.1です。ここは、初心者のかたは、気になさらず。

学習(自作関数使って)

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)

上で作成した関数を使って学習します。ファインチューニングなので、エポック数25くらいでいいのでしょう。

推論と可視化(自作関数使って)

visualize_model(model_ft)

1行で推論結果を表示できます

特徴抽出器を固定する場合(転移学習)

学習済みモデルの特徴抽出器部分の重みを変えずに、全結合部だけ追加学習するほうほうもありますよ。転移学習とか呼ばれています。

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

2~3行目が、ファインチューニングと異なるところで、重みの更新をしないように設定しています。

まとめ

いかかでしたが、少し長かったかもしれませんが、CNNの学習・推論をイメージできたのではないでしょうか?

tensorflow版の CNNのサンプルコード解説や

Tensorflow keras CNN(convolutional neural network)のサンプルコード説明 初心者向けです。
初心者の方向けに、Tensorflow keras CNN(convolutional neural network)のサンプルコードを解説します。画像認識のアルゴリズムの王道ですので、画像系AIのエンジニアを目指す方は、ぜひ参考になさってください。

CNNより主流になりつつあるVisionTransformaerの解説もよかったら参考になさってください。

Vision transformer (ViT)を用いた画像認識のコード解説。初心者向 けにtensorflow keras APIのコードをわかりやすく解説します。
画像認識のアルゴリズムで最近注目されている、Vision Transformer(ViT)のサンプルコードを解説します(Tensorflow keras API)。初心者の方にも理解しやすいように、必要以上に情報を詰め込まずに平易な文章で説明します。まずは手軽に実行してみましょう!
Pytorch、ビジョントランスフォーマー DeiTサンプルコード解説
Facebookが公開している画像分類タスク向け、VisionTransforme(ビジョントランスフォーマー)モデル、DeiTのサンプルコードを、初心者の方向けに解説します。細部の理論はとりあえずおいて、動作することを目標に解説しますので、ご活用ください。

タイトルとURLをコピーしました