バレンタインデーにもらったチョコが本命かどうか判定するAIを実際に作って公開した(2019)
09
- 2月
2019
Posted By : boomin
バレンタインデーにもらったチョコが本命かどうか判定するAIを実際に作って公開した(2019)

TL;DR

今年もバレンタインデーが近いですね!去年のものより、AIもフロントエンドも刷新してリリース!です。

バレンタインでもらったチョコが、本命かどうかは、世の男性にとって、重大な問題です。 女性心理の機微に疎い男子中高生、理系男子学生、男性エンジニア(偏見) を救うべく、実際にもらったチョコが本命かどうか判定する AIを開発しました

もちろん、このAIは完璧ではありません。本命のつもりなのに義理と、あるいはその逆で、義理のつもりなのに本命の判定してしまうこともあり得ます。

ですので、受け取った男性諸氏が判定をするときは、一人でこっそり楽しむようにしてくださいね。。。 相手の女性の気持ちを考え、決して踏みにじることはしないようにお願いします!

注)この文章をもとに、Qiitaにも投稿しています。


1. 男性を悩ませるバレンタイン問題

Advertisements

1.1 バレンタインデー

世界中にあるイベントの一つに、2/14のバレンタインデー(Valentine’s Day)があります。 国や地域によって、その日の過ごし方などは異なるようですが、ここ日本では、女性から男性へチョコレートを贈る日、という お菓子会社の戦略によって定着した文化 があります(例えば[1]参照。諸説あり)。

[1]によれば、1970年代後半頃に本命チョコを贈る習慣が定着したようです。ここから、さらにお菓子業界の商機ととらえた全国飴菓子工業協同組合は、義理チョコとホワイトデーという文化を創り出し、1984年をホワイトデー定着の年としました。

1.2 バレンタインデーの悲劇

しかし、このような文化は悲劇を起こしました。それは、女性から贈られたチョコが、本命なのか義理なのかがはっきりしないケースが登場したためです。はっきり口にせず、以心伝心や忖度を是 とする日本の文化が、多くの迷える男性を生み出して しまうことになりました。

このような背景から、贈られたチョコレートが、本命なのか義理なのかをはっきりさせる社会的ニーズが産まれました。図1に、この状況証拠を示します。世の中のどれだけの人が検索をすれば、Googleの検索サジェスチョンにこのような結果をもたらすのでしょうか。これは、もはや国民的課題といえるでしょう。 以降、これを バレンタイン問題 と呼ぶことにします。

図1

図1:バレンタイン問題 – バレンタインのチョコレートが本命か義理かを判断する社会的ニーズ

1.3 バレンタインデーの救世主

一方、最近の機械学習手法の発達は目覚ましい物があります。AIすなわち、機械学習がこの バレンタイン問題 を解決する手段になりえることが示されました[2]。これは画期的な観点であると思われましたが、開発したエンジニアのみが実行、評価できるものでした。これでは、__迷える男子中高生、理系男子大学生、そして男性エンジニア__を救う ことはできません(筆者の独断と偏見による)。

また、男性の立場からすると、ホワイトデーにお返しを行う義務が、暗黙のうちに発生します。ここで不適切な対応をとると、その後の長期間にわたり、生きづらい人生を歩むことになりかねません。 そのため、贈られたチョコレートに対して、どのくらいのお返しを行えばよいのか、その指針を示すことができれば、悩める男性に対して有効なサービスとなり得ると思われます。図2に、悩める男性像の一例を示します。

図2 バレンタインのお返しに苦悩する男性の一端

1.4 バレンタイン問題の解決方法

以上のことから、本記事では[2]をさらに発展させ、バレンタイン問題の一助となることを目指すものです。

  • 本命チョコと義理チョコを判別する学習モデルを開発する
    • 二値分類と呼ばれる、機械学習における代表的な手法を用いたアプローチをとる
    • これを本命・義理チョコ分類問題と呼ぶ
  • 開発した、構築した学習モデルをサービス化して公開する
    • PC、スマホのどちらからでも使えるように、Webサービス化を行う
  • 義理チョコと本命チョコを分類する際の確率を算出する
    • 定量的に判断指標を示すことで、悩める男性自身の判断を促す
    • 確率を用いることで、ホワイトデーのお返し予算の提案も併せて行う

実際のイメージを、以下に示します。

図3:学習モデルと本命/義理チョコの判別のイメージ

なお、筆者が人生でチョコレートを貰った回数(個数)は、年齢のおよそ25%以下である。 チョコ欲しい

2. DNNでの学習

昨年、ここに投稿したように、自力で画像をすべて集め、それだけを元にしてCNNで学習をさせました。

が、いまどきそれではイマイチですよね。

そこで、転移学習させることで、より優秀な学習モデルを構築することを目指します!

2.1 学習データの収集

何はともあれ、学習データの収集をやらねばならないことには変わりません。 そこで、ここに投稿した方法で、学習用の画像を集めてきました。

2.1.1 学習用に画像収集した検索クエリ

このリンク先の記事で示したsource codeでは、2パターンの検索クエリしか示していませんが、実際には、以下のような検索パターンで画像を収集しました。

config = {
  "Records": [
    {
      "keywords": "バレンタインデー 本命 チョコ",
      "limit": 10000,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"honmei",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 本命 チョコ 手作り",
      "limit": 10000,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"honmei",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 本命チョコ  ラッピング",
      "limit": 10000,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"honmei",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 義理 チョコ",
      "limit": 10000,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"giri",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 義理 ばらまき",
      "limit": 500,
      "no_numbering": True,
      #"format":"jpg",
      "output_directory": "images",
      "image_directory":"giri",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 義理 ブラックサンダー",
      "limit": 200,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"giri",
      "chromedriver":"/path/to/chromedriver.exe",
    },
    {
      "keywords": "バレンタインデー 義理 ネタ",
      "limit": 500,
      "no_numbering": True,
      "output_directory": "images",
      "image_directory":"giri",
      "chromedriver":"/path/to/chromedriver.exe",
    },
  ]
}
本命チョコ画像義理チョコ画像

いまいち雰囲気の違いが分かりにくいですが、そこはきっと、Neural Networkが特徴を獲得してくれるに違いありません!(キリッ

Advertisements

2.1.2 収集した画像の前処理

取得した画像は、そのままでは学習用データとして使えません。 その理由は、

  • 画像の形式がバラバラ
    • jpg/png/gifといったファイル形式
    • 画像のチャネル数もバラバラ
  • 同じ画像が、別ファイル名で複数存在している
    • 本命チョコ画像、義理チョコ画像の両方に同じ画像があったら、何を学習したのかわからなくなっちゃいますよね

ということで、これを解決せねばなりません。そこで、画像形式のRGBA形式に一括変換と、同じshapeかつ同じファイル名を画像を一括削除しました。

しかしこれだけでは、同じ画像サイズが異なるものを除外できません。 そこで、同じ特徴を持つ画像を検索して、削除することを行いました。こんな感じです。

うん、確かに同じ画像ですよね!

こうした地道な努力を重ねて、最終的に得られた学習用の画像の枚数は、以下の表のようになりました。

種類最終的に得られた枚数
本命チョコ画像1,729枚
義理チョコ画像2,057枚

なお、学習実行時にepoch単位で画像の水増しを行います。 水増しの程度は、300~1000倍に水増しします。

水増しに幅があるのは、バッチサイズなど実行時のパラメタによって動的に変更させているからです。 これも、GPUメモリに制限があるから。。。。。

2.2 学習モデル

以前、CNNを作った時は、学習モデルはすべて自前で構築して、0から学習を行いました。しかし、画像の機械学習でこの方法は、勉強のためにやってみるのはいいのですが、精度などあまりメリットがありません。

そこで、学習済みモデルを別途取得して、これを修正して利用する転移学習を行うことにします。 画像分類器にはkerasのデータセットとして用意されている「Xception」を利用させてもらうことにします。

Xceptionモデルを、以下のように修正します。

  • 入力層
    • 画像サイズ数とチャネル数を指定したレイヤを追加
  • 出力層
    • 出力が二値分類となるように、Xception最後の全結合層を削除して、自前で構築した全結合層を結合

cnn_model.py

# -*- coding: utf-8 -*-
"""
Created on Wed Jan 16 14:00:06 2019
@author: boomin
"""

import os,glob,sys
from keras.applications.xception import Xception
from keras.models import Sequential,Model
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.layers import GaussianNoise
from keras import backend as tensorflow_backend
from keras.layers import BatchNormalization
from keras import optimizers

def cnn_model(num_label, iw, ih, actFunc):
    # include_top=Falseによって、モデルから全結合層を削除
    input_tensor = Input(shape=(iw, ih, 3))
    xception_model = Xception(include_top=False, pooling='avg', weights="imagenet", input_tensor=input_tensor)
    #opt = Nadam(lr=0.0002, beta_1=0.9, beta_2=0.999, epsilon=None, schedule_decay=0.00001)
    
    # 全結合層の構築
    top_model = Sequential()
    top_model.add(Dropout(0.3))
    top_model.add(Dense(num_label,input_shape=xception_model.output_shape[1:]))
    #top_model.add(Activation("sigmoid"))
    top_model.add(Activation(actFunc))

    # 全結合層を削除したモデルと上で自前で構築した全結合層を結合
    model = Model(inputs=xception_model.input, outputs=top_model(xception_model.output))

    for layer in xception_model.layers[:-50]:
        layer.trainable = False

    model.compile(
      loss='binary_crossentropy',
      # optimizer=SGD(lr=1e-3, decay=1e-6, momentum=0.9, nesterov=True),
      optimizer = optimizers.Adam(lr=4e-3),
      #optimizer=Nadam(),
      metrics=['accuracy']
    )

    return model

上記のようにしてできたモデルは、以下のようになります。 長い。。。。。。

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 48, 48, 3)    0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 23, 23, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, 23, 23, 32)   128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 23, 23, 32)   0           block1_conv1_bn[0][0]            
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 21, 21, 64)   18432       block1_conv1_act[0][0]           
__________________________________________________________________________________________________
block1_conv2_bn (BatchNormaliza (None, 21, 21, 64)   256         block1_conv2[0][0]               
__________________________________________________________________________________________________
block1_conv2_act (Activation)   (None, 21, 21, 64)   0           block1_conv2_bn[0][0]            
__________________________________________________________________________________________________
block2_sepconv1 (SeparableConv2 (None, 21, 21, 128)  8768        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_sepconv1_bn (BatchNormal (None, 21, 21, 128)  512         block2_sepconv1[0][0]            
__________________________________________________________________________________________________
block2_sepconv2_act (Activation (None, 21, 21, 128)  0           block2_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block2_sepconv2 (SeparableConv2 (None, 21, 21, 128)  17536       block2_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block2_sepconv2_bn (BatchNormal (None, 21, 21, 128)  512         block2_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 11, 11, 128)  8192        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_pool (MaxPooling2D)      (None, 11, 11, 128)  0           block2_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 11, 11, 128)  512         conv2d_1[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 11, 11, 128)  0           block2_pool[0][0]                
                                                                 batch_normalization_1[0][0]      
__________________________________________________________________________________________________
block3_sepconv1_act (Activation (None, 11, 11, 128)  0           add_1[0][0]                      
__________________________________________________________________________________________________
block3_sepconv1 (SeparableConv2 (None, 11, 11, 256)  33920       block3_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv1_bn (BatchNormal (None, 11, 11, 256)  1024        block3_sepconv1[0][0]            
__________________________________________________________________________________________________
block3_sepconv2_act (Activation (None, 11, 11, 256)  0           block3_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block3_sepconv2 (SeparableConv2 (None, 11, 11, 256)  67840       block3_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv2_bn (BatchNormal (None, 11, 11, 256)  1024        block3_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 6, 6, 256)    32768       add_1[0][0]                      
__________________________________________________________________________________________________
block3_pool (MaxPooling2D)      (None, 6, 6, 256)    0           block3_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 6, 6, 256)    1024        conv2d_2[0][0]                   
__________________________________________________________________________________________________
add_2 (Add)                     (None, 6, 6, 256)    0           block3_pool[0][0]                
                                                                 batch_normalization_2[0][0]      
__________________________________________________________________________________________________
block4_sepconv1_act (Activation (None, 6, 6, 256)    0           add_2[0][0]                      
__________________________________________________________________________________________________
block4_sepconv1 (SeparableConv2 (None, 6, 6, 728)    188672      block4_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv1_bn (BatchNormal (None, 6, 6, 728)    2912        block4_sepconv1[0][0]            
__________________________________________________________________________________________________
block4_sepconv2_act (Activation (None, 6, 6, 728)    0           block4_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block4_sepconv2 (SeparableConv2 (None, 6, 6, 728)    536536      block4_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv2_bn (BatchNormal (None, 6, 6, 728)    2912        block4_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 3, 3, 728)    186368      add_2[0][0]                      
__________________________________________________________________________________________________
block4_pool (MaxPooling2D)      (None, 3, 3, 728)    0           block4_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 3, 3, 728)    2912        conv2d_3[0][0]                   
__________________________________________________________________________________________________
add_3 (Add)                     (None, 3, 3, 728)    0           block4_pool[0][0]                
                                                                 batch_normalization_3[0][0]      
__________________________________________________________________________________________________
block5_sepconv1_act (Activation (None, 3, 3, 728)    0           add_3[0][0]                      
__________________________________________________________________________________________________
block5_sepconv1 (SeparableConv2 (None, 3, 3, 728)    536536      block5_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv1_bn (BatchNormal (None, 3, 3, 728)    2912        block5_sepconv1[0][0]            
__________________________________________________________________________________________________
block5_sepconv2_act (Activation (None, 3, 3, 728)    0           block5_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv2 (SeparableConv2 (None, 3, 3, 728)    536536      block5_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv2_bn (BatchNormal (None, 3, 3, 728)    2912        block5_sepconv2[0][0]            
__________________________________________________________________________________________________
block5_sepconv3_act (Activation (None, 3, 3, 728)    0           block5_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv3 (SeparableConv2 (None, 3, 3, 728)    536536      block5_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv3_bn (BatchNormal (None, 3, 3, 728)    2912        block5_sepconv3[0][0]            
__________________________________________________________________________________________________
add_4 (Add)                     (None, 3, 3, 728)    0           block5_sepconv3_bn[0][0]         
                                                                 add_3[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1_act (Activation (None, 3, 3, 728)    0           add_4[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1 (SeparableConv2 (None, 3, 3, 728)    536536      block6_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv1_bn (BatchNormal (None, 3, 3, 728)    2912        block6_sepconv1[0][0]            
__________________________________________________________________________________________________
block6_sepconv2_act (Activation (None, 3, 3, 728)    0           block6_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv2 (SeparableConv2 (None, 3, 3, 728)    536536      block6_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv2_bn (BatchNormal (None, 3, 3, 728)    2912        block6_sepconv2[0][0]            
__________________________________________________________________________________________________
block6_sepconv3_act (Activation (None, 3, 3, 728)    0           block6_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv3 (SeparableConv2 (None, 3, 3, 728)    536536      block6_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv3_bn (BatchNormal (None, 3, 3, 728)    2912        block6_sepconv3[0][0]            
__________________________________________________________________________________________________
add_5 (Add)                     (None, 3, 3, 728)    0           block6_sepconv3_bn[0][0]         
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1_act (Activation (None, 3, 3, 728)    0           add_5[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1 (SeparableConv2 (None, 3, 3, 728)    536536      block7_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv1_bn (BatchNormal (None, 3, 3, 728)    2912        block7_sepconv1[0][0]            
__________________________________________________________________________________________________
block7_sepconv2_act (Activation (None, 3, 3, 728)    0           block7_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv2 (SeparableConv2 (None, 3, 3, 728)    536536      block7_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv2_bn (BatchNormal (None, 3, 3, 728)    2912        block7_sepconv2[0][0]            
__________________________________________________________________________________________________
block7_sepconv3_act (Activation (None, 3, 3, 728)    0           block7_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv3 (SeparableConv2 (None, 3, 3, 728)    536536      block7_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv3_bn (BatchNormal (None, 3, 3, 728)    2912        block7_sepconv3[0][0]            
__________________________________________________________________________________________________
add_6 (Add)                     (None, 3, 3, 728)    0           block7_sepconv3_bn[0][0]         
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1_act (Activation (None, 3, 3, 728)    0           add_6[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1 (SeparableConv2 (None, 3, 3, 728)    536536      block8_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv1_bn (BatchNormal (None, 3, 3, 728)    2912        block8_sepconv1[0][0]            
__________________________________________________________________________________________________
block8_sepconv2_act (Activation (None, 3, 3, 728)    0           block8_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv2 (SeparableConv2 (None, 3, 3, 728)    536536      block8_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv2_bn (BatchNormal (None, 3, 3, 728)    2912        block8_sepconv2[0][0]            
__________________________________________________________________________________________________
block8_sepconv3_act (Activation (None, 3, 3, 728)    0           block8_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv3 (SeparableConv2 (None, 3, 3, 728)    536536      block8_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv3_bn (BatchNormal (None, 3, 3, 728)    2912        block8_sepconv3[0][0]            
__________________________________________________________________________________________________
add_7 (Add)                     (None, 3, 3, 728)    0           block8_sepconv3_bn[0][0]         
                                                                 add_6[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1_act (Activation (None, 3, 3, 728)    0           add_7[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1 (SeparableConv2 (None, 3, 3, 728)    536536      block9_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv1_bn (BatchNormal (None, 3, 3, 728)    2912        block9_sepconv1[0][0]            
__________________________________________________________________________________________________
block9_sepconv2_act (Activation (None, 3, 3, 728)    0           block9_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv2 (SeparableConv2 (None, 3, 3, 728)    536536      block9_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv2_bn (BatchNormal (None, 3, 3, 728)    2912        block9_sepconv2[0][0]            
__________________________________________________________________________________________________
block9_sepconv3_act (Activation (None, 3, 3, 728)    0           block9_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv3 (SeparableConv2 (None, 3, 3, 728)    536536      block9_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv3_bn (BatchNormal (None, 3, 3, 728)    2912        block9_sepconv3[0][0]            
__________________________________________________________________________________________________
add_8 (Add)                     (None, 3, 3, 728)    0           block9_sepconv3_bn[0][0]         
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1_act (Activatio (None, 3, 3, 728)    0           add_8[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1 (SeparableConv (None, 3, 3, 728)    536536      block10_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv1_bn (BatchNorma (None, 3, 3, 728)    2912        block10_sepconv1[0][0]           
__________________________________________________________________________________________________
block10_sepconv2_act (Activatio (None, 3, 3, 728)    0           block10_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv2 (SeparableConv (None, 3, 3, 728)    536536      block10_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv2_bn (BatchNorma (None, 3, 3, 728)    2912        block10_sepconv2[0][0]           
__________________________________________________________________________________________________
block10_sepconv3_act (Activatio (None, 3, 3, 728)    0           block10_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv3 (SeparableConv (None, 3, 3, 728)    536536      block10_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv3_bn (BatchNorma (None, 3, 3, 728)    2912        block10_sepconv3[0][0]           
__________________________________________________________________________________________________
add_9 (Add)                     (None, 3, 3, 728)    0           block10_sepconv3_bn[0][0]        
                                                                 add_8[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1_act (Activatio (None, 3, 3, 728)    0           add_9[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1 (SeparableConv (None, 3, 3, 728)    536536      block11_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv1_bn (BatchNorma (None, 3, 3, 728)    2912        block11_sepconv1[0][0]           
__________________________________________________________________________________________________
block11_sepconv2_act (Activatio (None, 3, 3, 728)    0           block11_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv2 (SeparableConv (None, 3, 3, 728)    536536      block11_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv2_bn (BatchNorma (None, 3, 3, 728)    2912        block11_sepconv2[0][0]           
__________________________________________________________________________________________________
block11_sepconv3_act (Activatio (None, 3, 3, 728)    0           block11_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv3 (SeparableConv (None, 3, 3, 728)    536536      block11_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv3_bn (BatchNorma (None, 3, 3, 728)    2912        block11_sepconv3[0][0]           
__________________________________________________________________________________________________
add_10 (Add)                    (None, 3, 3, 728)    0           block11_sepconv3_bn[0][0]        
                                                                 add_9[0][0]                      
__________________________________________________________________________________________________
block12_sepconv1_act (Activatio (None, 3, 3, 728)    0           add_10[0][0]                     
__________________________________________________________________________________________________
block12_sepconv1 (SeparableConv (None, 3, 3, 728)    536536      block12_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv1_bn (BatchNorma (None, 3, 3, 728)    2912        block12_sepconv1[0][0]           
__________________________________________________________________________________________________
block12_sepconv2_act (Activatio (None, 3, 3, 728)    0           block12_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv2 (SeparableConv (None, 3, 3, 728)    536536      block12_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv2_bn (BatchNorma (None, 3, 3, 728)    2912        block12_sepconv2[0][0]           
__________________________________________________________________________________________________
block12_sepconv3_act (Activatio (None, 3, 3, 728)    0           block12_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv3 (SeparableConv (None, 3, 3, 728)    536536      block12_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv3_bn (BatchNorma (None, 3, 3, 728)    2912        block12_sepconv3[0][0]           
__________________________________________________________________________________________________
add_11 (Add)                    (None, 3, 3, 728)    0           block12_sepconv3_bn[0][0]        
                                                                 add_10[0][0]                     
__________________________________________________________________________________________________
block13_sepconv1_act (Activatio (None, 3, 3, 728)    0           add_11[0][0]                     
__________________________________________________________________________________________________
block13_sepconv1 (SeparableConv (None, 3, 3, 728)    536536      block13_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv1_bn (BatchNorma (None, 3, 3, 728)    2912        block13_sepconv1[0][0]           
__________________________________________________________________________________________________
block13_sepconv2_act (Activatio (None, 3, 3, 728)    0           block13_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block13_sepconv2 (SeparableConv (None, 3, 3, 1024)   752024      block13_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv2_bn (BatchNorma (None, 3, 3, 1024)   4096        block13_sepconv2[0][0]           
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 2, 2, 1024)   745472      add_11[0][0]                     
__________________________________________________________________________________________________
block13_pool (MaxPooling2D)     (None, 2, 2, 1024)   0           block13_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 2, 2, 1024)   4096        conv2d_4[0][0]                   
__________________________________________________________________________________________________
add_12 (Add)                    (None, 2, 2, 1024)   0           block13_pool[0][0]               
                                                                 batch_normalization_4[0][0]      
__________________________________________________________________________________________________
block14_sepconv1 (SeparableConv (None, 2, 2, 1536)   1582080     add_12[0][0]                     
__________________________________________________________________________________________________
block14_sepconv1_bn (BatchNorma (None, 2, 2, 1536)   6144        block14_sepconv1[0][0]           
__________________________________________________________________________________________________
block14_sepconv1_act (Activatio (None, 2, 2, 1536)   0           block14_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block14_sepconv2 (SeparableConv (None, 2, 2, 2048)   3159552     block14_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block14_sepconv2_bn (BatchNorma (None, 2, 2, 2048)   8192        block14_sepconv2[0][0]           
__________________________________________________________________________________________________
block14_sepconv2_act (Activatio (None, 2, 2, 2048)   0           block14_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 2048)         0           block14_sepconv2_act[0][0]       
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 2)            4098        global_average_pooling2d_1[0][0] 
==================================================================================================
Total params: 20,865,578
Trainable params: 12,172,402
Non-trainable params: 8,693,176
__________________________________________________________________________________________________

このモデルを使って、いよいよScrapingで得られた画像を学習させていきます。

2.3 学習の実行

学習実行では、もちろん交差検証を行うことで、汎化性能を持たせたいと思います。 さらに、ImageDataGeneratorで、epoch単位での画像の水増しをリアルタイムで行います。 epoch単位で水増し画像を生成することで、GPUメモリを節約することができます。

datagen = ImageDataGenerator(
  rescale=1./255,
  rotation_range=360, # 90°まで回転
  width_shift_range=0.2, # 水平方向にランダムでシフト
  height_shift_range=0.2, # 垂直方向にランダムでシフト
  shear_range=0.39, # 斜め方向(pi/8まで)に引っ張る
  horizontal_flip=True, # 垂直方向にランダムで反転
  vertical_flip=True, # 水平方向にランダムで反転
  zoom_range=0.2,
  fill_mode='wrap',
)

# 水増し画像を訓練用画像の形式に合わせる
datagen.fit(trainX)

# 画像の水増し量
amprate = 20 * len(trainX) / batch_size * cv_num/(cv_num-1)

# 学習実行
hist = model.fit_generator(
  datagen.flow(
    trainX, trainY, 
    batch_size=batch_size, 
  ),
  steps_per_epoch= amprate,
  epochs=epochs,
  shuffle=True,
  verbose=1,
  validation_data = (evalX, evalY),
  validation_steps=10,
  callbacks=[es,rl]
)

画像の水増し量は、batch sizeやcross validationの回数で動的に決めることにしています。 trainXは元データの画像数のうち、cross validationで分割された訓練用画像数です。つまり、

trainX = 3,786(1,729枚+2,057枚) * (cv_num-1)/cv_num

となります。

batch_sizeを96、cv_numを10と指定した時、画像の水増し量amprateは、

amprate = 20 * 3786*0.9 / 96 * 10/9 = 788.75

となります。したがって、今回の場合は実質

3,786枚×788 = 2,983,368

枚の画像で学習させることになります。学習データとしては、まぁ十分じゃないでしょうか。

また、

callbacks=[es,rl]

としていますが、これはEarlyStoppingとReduceLROnPlateauを指定しています。 いい感じに精度や誤差が収束したら、学習率を徐々に下げたり、学習をストップさせたりするすることで、いろいろとパラメタチューニングを繰り返すための計算効率を向上させることをしています。

2.4 学習モデルの評価

チューニングさせた主なパラメタは、以下の点です。

  • 活性化関数
  • バッチサイズ
  • 初期学習率
  • 交差検証の分割数

2.4.1 バッチサイズの影響

その中でも、活性化関数をsigmoid、初期学習率を0.004、交差検証の分割数を10で固定し、バッチサイズを48~128で振った時のlearning curve(学習曲線)とvalidation curveを見てみましょう。

以下に、バッチサイズのみを変えて学習させたときの学習曲線を示します。

学習曲線コメント
バッチサイズ:48
バッチサイズ:64
バッチサイズ:96
バッチサイズ:128

どの場合でも、精度は 95%程度 の正答率を得られていることがわかります。validation accuracyは、、、、、思ったほど高くなっていませんが、およそ汎化性能として 70%程度 の精度が得られていることがわかります。

バッチサイズは、あまり大きすぎても良くなさそうです。バッチサイズを大きくすると、学習効率が上がります。しかし、その分、より多くの学習データをもとに学習させることになるので、1回の学習でより様々な画像から特徴を見出さなければいけないため、val_accが上がりにくくなっているのか・・・?と想像しています。

次に、損失曲線も見てみることにしましょう。

損失曲線コメント
バッチサイズ:48
バッチサイズ:64
バッチサイズ:96
バッチサイズ:128。このグラフだけ、縦軸が対数となっていることに留意

過学習しているとまでは言えませんが、val_lossが下がりきった感じがしないし、そもそもその値も大きいし、十分に損失を小さくできたとは言えないですね。でもまぁ、そもそもこのプロジェクト自体、自分の勉強も兼ねてのネタだし、精度の追及は、頑張らないことにします。

3. お返し予算の算出

長くなったので、この内容はこちらに整理しました

結論だけ紹介すると、以下のグラフのような結果としました。

Advertisements

コメントを残す