Amazon SageMaker を使っておでん種を検出したかった

TL;DR

SageMaker を使って素人が見よう見まねで機械学習してみた結果、散々な結果に終わった。

大根・しらたき・がんもどきを学習させたのに大根しか出ねぇ……
(図示するしきい値をゆるくすると大根以外も検出されてたので、一応、バグの類ではない模様)

ちなみに、しらたきだけ学習させたときはいい感じだったので、しらたき検出器だけで良いのであれば実用的かもしれない。

SageMaker とは

Amazon Web Services ブログより

だれでも機械学習ができる環境を御膳立してくれるすごいサービス。
具体的には、下記の要素を提供しているらしい。

  • Jupyter Notebook 用インスタンス
    • 頭いい人たちがよく使ってる Web GUI な Python 実行環境
    • トレーニング用データの前処理 (Data Wrangling) 等を行う
    • AWS によるサンプルノートブックがたくさんある
  • モデルトレーニングジョブの実行環境
    • トレーニングジョブをぶん回す環境
    • スケールアップ・スケールアウトも簡単にできる
  • モデルホスティング
    • 構築したモデルを使って推論をおこなうエンドポイントをデプロイ・ホスティング

あとは己にイカした脳みそが搭載されていれば、面倒くさい環境構築などをすっ飛ばしていい感じに機械学習に取り掛かれる。

SageMaker (の組み込みアルゴリズム) でできること

Apache MXNetTensorFlow といった有名な機械学習用フレームワークも使用できるが、
初心者はとりあえず SageMaker の組み込みアルゴリズムを使えば良いのではないだろうか。

今回は、画像中のおでん種の検出を行うので、物体検出 用のアルゴリズムを使用した。

手順

AWS のサンプルノートブック をベースに大根・しらたき・がんもどきのデータセットを独自に用意した。

トレーニング用画像の取得

Google 画像検索の結果ページに表示されるサムネイル画像を拝借した。 拾ってきたままだと、関係ない画像やトレーニング用として不適切そうな画像 (例: 食べかけの大根) が紛れ込んでいるので手動で選別した。

最終的に下記の枚数になった。

おでん種枚数
大根64 枚
がんもどき34 枚
しらたき32 枚

画像の水増し (Image Augmentation)

機械学習には大量のトレーニング用データが必要になる。 例として、よく見かける The MNIST database の手書き数字データには 6 万枚のトレーニング用画像が含まれる。

Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. “Gradient-based learning applied to document recognition.” Proceedings of the IEEE, 86(11):2278-2324, November 1998

よって、今回は選りすぐりの少ないおでん画像を Augmentor を使って水増しした。
具体的には、下記の加工を施したバリエーションを増やした。

  • 上下左右反転
  • 回転
  • コントラスト変更
  • 歪ませる

これにより、大根・がんもどき・しらたき画像がそれぞれ 100 枚ずつとなった。
(これでも少ないが、後述のアノテーション作業がつらいので……)

アノテーション

今回用いた物体検出 は、教師あり学習である。 つまり、先の手順で拾ってきた画像に対して正解データを用意する必要があるんですね。

…………と、このように Microsoft の VoTT (Visual Object Tagging Tool) を用いて、上の動画のような作業を全画像 (300枚) に対しておこなった。

タグ付け後は、クラスメソッドさんの記事にあるスクリプトを使って VoTT が出力した JSON を指定された JSON 形式 へ変換した。

S3 へアップロード

下記のようなディレクトリ構造になるよう、画像ファイルと JSON ファイルを S3 へアップロードする。

  • s3://your_bucket/train/xxxx.jpg
    • トレーニング用画像ファイル
  • s3://your_bucket/train_annotation/xxxx.json
    • トレーニング用画像ファイルのアノテーション JSON
  • s3://your_bucket/validation/yyyy.jpg
    • 検証用画像ファイル
  • s3://your_bucket/validation_annotation/yyyy.json
    • 検証用画像ファイルのアノテーション JSON

マネジメントコンソールのアップロード画面へ適宜ファイルをドロップすれば済む。

トレーニング・デプロイ・推論

とりあえず、下記のリンクから Jupyter Notebook を開き、SageMaker Examplesobject_detection_image_json_format.ipynbUse。トレーニング用画像周りの処理 (Data Preparation のところ) などを変更して実行。

Setup

%%time
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
print(role)
sess = sagemaker.Session()

bucket = '<your_s3_bucket_name_here>'
prefix = 'DEMO-ObjectDetection'

from sagemaker.amazon.amazon_estimator import get_image_uri

training_image = get_image_uri(sess.boto_region_name, 'object-detection', repo_version="latest")
print (training_image)

Prepare dataset

画像等は前項で S3 アップロード済み。

%%time

train_channel = prefix + '/train'
validation_channel = prefix + '/validation'
train_annotation_channel = prefix + '/train_annotation'
validation_annotation_channel = prefix + '/validation_annotation'

s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)
s3_train_annotation = 's3://{}/{}'.format(bucket, train_annotation_channel)
s3_validation_annotation = 's3://{}/{}'.format(bucket, validation_annotation_channel)

s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)

Training

od_model = sagemaker.estimator.Estimator(training_image,
                                         role,
                                         train_instance_count=1,
                                         train_instance_type='ml.p3.2xlarge',
                                         train_volume_size = 50,
                                         train_max_run = 360000,
                                         input_mode = 'File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess
                                         )

#
# Object Detection Hyperparameters
# https://docs.aws.amazon.com/sagemaker/latest/dg/object-detection-api-config.html
#
od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=3, # 検出する物体の数
                             mini_batch_size=4, # なんかエラーが出てジョブが失敗する場合は減らすと良いらしい
                             epochs=50,
                             learning_rate=0.001,
                             lr_scheduler_step='10',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=512,
                             label_width=350,
                             num_training_samples=300 # トレーニング用画像数
                             )


train_data = sagemaker.session.s3_input(s3_train_data, distribution='FullyReplicated', 
                        content_type='image/jpeg', s3_data_type='S3Prefix')
validation_data = sagemaker.session.s3_input(s3_validation_data, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
train_annotation = sagemaker.session.s3_input(s3_train_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')
validation_annotation = sagemaker.session.s3_input(s3_validation_annotation, distribution='FullyReplicated', 
                             content_type='image/jpeg', s3_data_type='S3Prefix')

data_channels = {'train': train_data, 'validation': validation_data, 
                 'train_annotation': train_annotation, 'validation_annotation':validation_annotation}

# 10分とか、そういうオーダーで時間がかかる
od_model.fit(inputs=data_channels, logs=True)

Hosting

object_detector = od_model.deploy(initial_instance_count = 1,
                                 instance_type = 'ml.m4.xlarge')

Inference (推論)

お待ちかねの推論。

画像は下記のボタンからノートブックインスタンスへアップロードできる。

# Jupyter Notebook にアップロードした画像ファイル名
file_name = 'XXXX.jpg'

with open(file_name, 'rb') as image:
    f = image.read()
    b = bytearray(f)
    ne = open('n.txt','wb')
    ne.write(b)

import json

object_detector.content_type = 'image/jpeg'
results = object_detector.predict(b)
detections = json.loads(results)
print (detections)

def visualize_detection(img_file, dets, classes=[], thresh=0.6):
        """
        visualize detections in one image
        Parameters:
        ----------
        img : numpy.array
            image, in bgr format
        dets : numpy.array
            ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
            each row is one object
        classes : tuple or list of str
            class names
        thresh : float
            score threshold
        """
        import random
        import matplotlib.pyplot as plt
        import matplotlib.image as mpimg

        img=mpimg.imread(img_file)
        plt.imshow(img)
        height = img.shape[0]
        width = img.shape[1]
        colors = dict()
        for det in dets['prediction']:
            (klass, score, x0, y0, x1, y1) = det
            if score < thresh:
                continue
            cls_id = int(klass)
            if cls_id not in colors:
                colors[cls_id] = (random.random(), random.random(), random.random())
            xmin = int(x0 * width)
            ymin = int(y0 * height)
            xmax = int(x1 * width)
            ymax = int(y1 * height)
            rect = plt.Rectangle((xmin, ymin), xmax - xmin,
                                 ymax - ymin, fill=False,
                                 edgecolor=colors[cls_id],
                                 linewidth=3.5)
            plt.gca().add_patch(rect)
            class_name = str(cls_id)
            if classes and len(classes) > cls_id:
                class_name = classes[cls_id]
            plt.gca().text(xmin, ymin - 2,
                            '{:s} {:.3f}'.format(class_name, score),
                            bbox=dict(facecolor=colors[cls_id], alpha=0.5),
                                    fontsize=12, color='white')
        plt.show()

object_categories = ['daicon', 'ganmodoki', 'shirataki']
threshold = 0.2

visualize_detection(file_name, detections, object_categories, threshold)

ここまで正常に実行できたら、冒頭の画像のような検出結果が可視化される。

後片付け

放置すると、1 エンドポイントで数百 USD / Month になるので、不要なエンドポイントは必ず削除。

sagemaker.Session().delete_endpoint(object_detector.endpoint)

※マネジメントコンソールからも削除可能
※不要なら、ノートブックインスタンスも同様に削除

まとめ

  • 大根しか出ねぇ……
  • とりあえずサンプルノートブックを模倣すれば、機械学習素人でもそれっぽいことができるのですごい
    • ソフト・インフラ両面からの手厚いサポートにより、機械学習をコモディティ化を促進させてインフラで儲けようとする AWS の企みが
  • 次を真面目にやればもっとマシになるはず