Mediapipe(google)で画像認識AIのサンプルコード解説。初心者向けです。

データサイエンス

簡単に機械学習を実装できるMediaPipeをgoogleが公開しています。

機械学習のライブラリは年々増えてきて、何を使うか悩んでしまいますね。実装の手軽さという点で優れているのが、googleのMediaPipeです。 

手軽に、スマホやPCのカメラと連携して「画像認識」「物体検知」「セグメンテーション」ができてしまいます。今回は、MediaPipeの中で、「画像認識」のサンプルコードを解説します。初心者向けになります。

はじめに

今回説明するサンプルコードは、学習済みの画像認識モデルを使って、手持ちの画像データを分類するコードです。非常に軽量なモデルをつかっているため、GPUなしでも、実用的なスピードで動作します。

サンプルコードはこちらです。

Image classification guide for Python  |  MediaPipe  |  Google for Developers

サンプルコード解説

準備

!python -m pip install mediapipe

mediapipeをインストールします。コマンドラインで実行するときは”!”抜きましょう。

!wget -O classifier.tflite -q https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite

学習済みのモデルをwebページからダウンロードします。コマンドラインで実行するときは”!”抜きましょう。

Efficientnetの軽量モデルですね。

インポート

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

はじめに、mediapipeの各種ライブラリをインポートします

画像表示関連の関数

少し長いので初心者のかたは読み飛ばしてもOKです。本題のとことではないです。。。

from matplotlib import pyplot as plt
plt.rcParams.update({
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.spines.left': False,
    'axes.spines.bottom': False,
    'xtick.labelbottom': False,
    'xtick.bottom': False,
    'ytick.labelleft': False,
    'ytick.left': False,
    'xtick.labeltop': False,
    'xtick.top': False,
    'ytick.labelright': False,
    'ytick.right': False
})


def display_one_image(image, title, subplot, titlesize=16):
    """Displays one image along with the predicted category name and score."""
    plt.subplot(*subplot)
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize), color='black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)

def display_batch_of_images(images, predictions):
    """Displays a batch of images with the classifications."""
    # Images and predictions.
    images = [image.numpy_view() for image in images]

    # Auto-squaring: this will drop data that does not fit into square or square-ish rectangle.
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows

    # Size and spacing.
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))

    # Display.
    for i, (image, prediction) in enumerate(zip(images[:rows*cols], predictions[:rows*cols])):
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols) * 40 + 3
        subplot = display_one_image(image, prediction, subplot, titlesize=dynamic_titlesize)

    # Layout.
    plt.tight_layout()
    plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
  • 1行目:matplotlibをインポート
  • 2~15行目:matplotの設定をしています。グラフではなく、画像なので、線などを表示しないようにしています。
  • 18~24行目:1枚画像を表示する関数
  • 20行目:plt.subplotで画像複数画像描画するときのどこに書くかを指定します
  • 21行目:画像表示
  • 23行目:タイトルがあるならば
  • 24行目:画像にタイトルを書きます
  • 25行目:subplotを返します
  • 26~52行目:バッチ処理で画像を描画する関数です。中で1枚画像を表示する関数をつかってます。
  • 29行目:画像データをnumpy形式に変換。リスト内包表記だと一行で済みますね。
  • 32行目:複数画像表示するときの、行数。縦方向に画像何個表示するか。
  • 33行目:複数画像表示するときの、列数。横方向に画像何個表示するか。
  • 36行目:Figのサイズ
  • 37行目:スペースのサイズ
  • 38行目:サブプロットの行数、列数を指定しておきます。
  • 39行目:行<列数ならば
  • 40、42行目:figのサイズ指定
  • 45行目:画像のリストと予測結果のリストをfor文で回します
  • 46行目:画像のタイトルサイズを指定。subplotの行・列数に合わせて
  • 47行目:1枚づつ表示する関数を使って表示
  • 50行目:狭い感じのレイアウト
  • 51行目:subplotの間隔を設定
  • 52行目:表示

テスト画像をダウンロードして用意する場合

import urllib

IMAGE_FILENAMES = ['burger.jpg', 'cat.jpg']

for name in IMAGE_FILENAMES:
  url = f'https://storage.googleapis.com/mediapipe-tasks/image_classifier/{name}'
  urllib.request.urlretrieve(url, name)

なんでもいいのですが、テスト用の画像を準備します。上の例では、ハンバーガと猫の画像をダウンロードしています。(ここは本題ではないので、あまり説明しないです。。。)

テスト画像を自分で用意して、colabにアップロードする場合

from google.colab import files
uploaded = files.upload()

for filename in uploaded:
   content = uploaded[filename]
   with open(filename, 'wb') as f:
     f.write(content)
 IMAGE_FILENAMES = list(uploaded.keys())

 print('Uploaded files:', IMAGE_FILENAMES)
  • 1行目:colabのライブラリをインポート
  • 2行目:ファイルをアップロードするオブジェクト
  • 4行目:ファイルの数だけfor分回す
  • 6行目:元々のファイル名のまま保存します。colabのrootディレクトリに
  • 7行目:ファイル保存します
  • 8行目:ファイル名のリスト作成
  • 10行目:ファイル名を表示

テスト画像を表示して確認

import cv2
from google.colab.patches import cv2_imshow
import math

DESIRED_HEIGHT = 480
DESIRED_WIDTH = 480

def resize_and_show(image):
  h, w = image.shape[:2]
  if h < w:
    img = cv2.resize(image, (DESIRED_WIDTH, math.floor(h/(w/DESIRED_WIDTH))))
  else:
    img = cv2.resize(image, (math.floor(w/(h/DESIRED_HEIGHT)), DESIRED_HEIGHT))
  cv2_imshow(img)


# Preview the images.

images = {name: cv2.imread(name) for name in IMAGE_FILENAMES}
for name, image in images.items():
  print(name)
  resize_and_show(image)
  • 1~3行目:各種ライブラリをインポート。opencvを使っていますね。
  • 5行目:リサイズ後の画像の高さ
  • 6行目:リサイズ後の画像の幅
  • 8行目:リサイズと表示
  • 9行目:画像データの高さと、幅
  • 10行目:縦長画像ならば
  • 11行目:画像をリサイズ。幅を480pixにして、高さは縦横比を変えないように調整
  • 12行目:横長なら
  • 13行目:高さを480pixにして、幅は縦横比を変えないように調整。
  • 14行目:画像表示。面倒なんで、cv2で表示。
  • 19行目:画像データをcv2形式で読み込みます。辞書形式で格納
  • 20行目:画像の辞書をfor文で回して
  • 21行目:画像名を表示して
  • 22行目:画像をリサイズして、表示

今回試すのはこんな画像です

推論と結果表示

# STEP 1: Import the necessary modules.
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python.components import processors
from mediapipe.tasks.python import vision

# STEP 2: Create an ImageClassifier object.
base_options = python.BaseOptions(model_asset_path='classifier.tflite')
options = vision.ImageClassifierOptions(
    base_options=base_options, max_results=4)
classifier = vision.ImageClassifier.create_from_options(options)

images = []
predictions = []
for image_name in IMAGE_FILENAMES:
  # STEP 3: Load the input image.
  image = mp.Image.create_from_file(image_name)

  # STEP 4: Classify the input image.
  classification_result = classifier.classify(image)

  # STEP 5: Process the classification result. In this case, visualize it.
  images.append(image)
  top_category = classification_result.classifications[0].categories[0]
  predictions.append(f"{top_category.category_name} ({top_category.score:.2f})")

display_batch_of_images(images, predictions)
  • 2~5行目:メディアパイプの必要なモジュールをインポートします。
  • 8行目:モデルの設定です。モデルのファイルを指定します。
  • 9行目:ここも、モデル設定です。引数で8行目で指定したモデルと、max_resultsを与えています。max_resultsは予測結果の上位いくつまでを結果として返すかです。
  • 11行目:予測モデルのオブジェクトを作成しています
  • 13行目:画像をためる空のリスト
  • 14行目:結果をためる空のリスト
  • 15行目:画像のパスが入っているリストから、画像パスを1個づつ取り出してfor分回します。
  • 17行目:画像を読み込みます。
  • 20行目:予測モデルに画像を入力して、分類結果を出します
  • 23行目:画像データはリストにためておいて
  • 24行目:分類結果のtop1をとりだして。(9行目で、top4出力にするのはなぜでしょうね?。。。)
  • 25行目:分類結果とリストにためておきます。
  • 27行目:画像と、分類結果を表示します。上のほうで作成した自作関数をつかって。

手持ちデータでためした結果はこれです。

大型飛行機と間違いなく分類できてますね。

まとめ

PC、スマホ、カメラなどと連携しやすい、Mediapipeの画像認識サンプルコードの解説をしました。初心しゃの方は、途中の関数などわかりにくいと思いますが、そこは飛ばして、Mediapipeを使って実際に手持ち画像を分類してみましょう。

tensroflowやpytorchと比べると簡単かと思いますので、これから主流になるかもしれません。ぜひ、身に着けてください。

タイトルとURLをコピーしました