[Python] 漫画のWikipediaの説明文から発表年を推定する その2 どの単語が分類に有用かを調べる

こんにちはLink-Uの町屋敷です。

前回はWikipediaの漫画の説明文から発表年を推定しました。

そこそこ推定できましたが、そもそも漫画の説明文から発表年を推定してなにがうれしいかって特に生産性は無いんですよね、

しかし、学習器自体に生産性が皆無でも学習器がどんな基準で推定したかがわかれば何か別の有意義な情報が得られる可能性があります。

例えば今回で言うと入力は文章中に出てくる単語なので、時代とともに出現する単語の傾向を知ることが出来ます。

また、学習器の精度を上げていくには、入力する特徴量と答えの関係を色んな角度から見てある仮説を立てて、特徴量を削ったり加工していくんですが、そのあたりとして使うことも出来ます。

そこで前回はアンサンブル系の学習機を使ったときに学習に利用されるfeature_importanceというパラメーターで単語の重要度を見ましたが、

今回は別の方法でやってみます、そこで使うのがLIMEです。

LIMEはどんなアルゴリズムで学習器を作っても入力された特徴量の重要度を出力してくれるスグレモノで、マルチクラスにも対応しています。

LIMEを使う

LIMEは回帰問題でも使えるんですが、マルチクラス分類のほうがわかりやすいのでそちらを使います。

もともと回帰の問題だったのでこれをある期間ごとに分け直します。

def _ConvertLabelForClf(labels):
    for i , v in enumerate(labels):
        if v <= 1990:
            labels[i] = 0  
        elif v <= 2005:
            labels[i] = 1
        elif v <= 2010:
            labels[i] = 2 
        else:
            labels[i] = 3
    return labels

今回は1990年以前、1996-2005,2006-2010,2011-の4クラスに分けました。

前回はGradiantBoostingの回帰の学習器を作りましたが、今回は分類の学習器を作ります。

パワメータ調整などやることはほぼ変わらないです。4クラスの数はすべて同じにはならないので、不均等データのための処理(アンダーサンプリングなど)を行う必要も考えましたが、何もせずにとりあえず回した結果がそんなに悪くなかったのでやってません。

ライブラリはpipで入ります。

def LearnClassifier():
    
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
     
    print(np.shape(features))
    print(np.shape(labels))
    
    tf = TfidfTransformer()
    features = tf.fit_transform(features)
    #features = _CompressWithAutoencoder(features)
    labels   = _ConvertLabelForClf(labels)
    
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)
    
    _, counts = np.unique(y_train, return_counts=True)
    weights = counts/len(y_train)
    weights[0] =  1 - weights[1] - weights[2] - weights[3]
    
    clf = LGBMClassifier(
        learning_rate =0.1, n_estimators=1000,
        max_depth=3,
        objective='multiclass',
    )
    
    print(weights)
    print('Learning Start')
    time = timer()
    clf.fit(X_train,y_train)
    time = timer() - time
    print('Learning Finishn Time: {0}'.format(time))
    
    joblib.dump(clf, '{0}/gradient_boosting_classifier.pkl'.format(WRITE_JOBLIB_DIR))
    y_pred = clf.predict(X_test)
    #for yt, yp in zip (y_test, y_pred):
    #    print((yt,yp))
    
    cm = confusion_matrix(y_test, y_pred, [0,1,2,3])
    print('n')
    print(cm)
    
    f1 = f1_score(y_test, y_pred, [0,1,2,3], average='macro')
    acx = accuracy_score(y_train, clf.predict(X_train), [0,1,2,3])
    acy = accuracy_score(y_test, y_pred, [0,1,2,3])
    
    print('f1 = {0},train_accuracy = {1}, test_accuracy = {2}'.format(f1, acx, acy))

コンフュージョンマトリックスとF1値はこんな感じ

そこそこといったところ。この学習器を使ってLIMEを試してみよう。

def InspectClassifier(dict_param = [5, 0.1]):
    dictionary = corpora.Dictionary.load_from_text('filtered_dic_below{0}_above{1}.txt'.format(dict_param[0], dict_param[1]))
    
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    clf = joblib.load('{0}//gradient_boosting_classifier.pkl'.format(WRITE_JOBLIB_DIR))
    
    
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)

まず、データのロード。train_test_splitのシードには先程の学習時に使ったものと同じものを使うこと。

    from lime import lime_tabular
    from sklearn.pipeline import make_pipeline
    c = make_pipeline(clf)
    
    class_names = ['< 1990', '< 2005', '< 2010', '>= 2010']
    
    id2token = np.empty(len(dictionary.token2id), dtype='U64')
    for k, v in dictionary.token2id.items():
        id2token[int(v)] = k
    
    exp = lime_tabular.LimeTabularExplainer(X_train, feature_names=id2token, class_names=class_names)

特徴の列と単語のマッピングにはとgensimで作った辞書はそのまま使えないので加工する。

LimeTextExplainer()があるのにLimeTabularExplainer()を使ってる理由は、テキストのほうに分かち書きされた原文を入力に要求されるため、

英語ならそのままぶち込めばいいが日本がだといちいち分けて原型に戻してやらないといけない。さすがにめんどくさい。あとマルチバイト対応しているのか謎。

あとは公式ドキュメントに沿ってやるだけ。

    idx = 0 
    for idx in range(100):
        lime_result = exp.explain_instance(X_test[idx], c.predict_proba, num_features=100, labels=(0,1,2,3))
        
        lime_result.save_to_file('./exp/exp_{0}_{1}.html'.format(idx, y_train[idx]))
        print('# {0} finished'.format(idx))

公式は結果の表示に

show_in_notebook()

を使ってるけどiPython入れてないと何も起こらないので注意!

また、入力する単語数が多すぎると

こんなのが出て進まない、単語数を減らすために、gensimの辞書生成のパラメーターを調整して単語数を減らす。

試しに単語数を減らす代わりにオートエンコーダーで次元圧縮して突っ込んでみたけど爆死したのでこっちのほうが確実?

結果としてhtmlファイルで出力されるのでブラウザーでみてみよう。

各文章の分類に利用された単語が影響度の高い順に出力されている。

つまり、1990年以下かどうかを表す一眼左の青いグラフが書かれているところに上から順に現像、全集、ビデオとあるが、これは文章中にこれらの単語が一度も出現しないとき、1990年以前でないクラスに文章が分類される可能性が高くなることを示している。条件が0以下つまり文章に無いときなってしまっているのでややこしいが、要は現像とか全集とかの単語がなかったらそんなに昔の作品じゃないんじゃない?ってこと。左上にどの程度の自信があって分類しているかがわかるのも地味に嬉しい。

まとめ

LIMEを使うとどの特徴量がどれだけ分類に貢献しているかがわかりやすい。SVMとかにも使えるのも良い点。

参考サイト

https://github.com/marcotcr/lime

https://qiita.com/fufufukakaka/items/d0081cd38251d22ffebf

[Python] 漫画のWikipediaの説明文から発表年を推定する

こんにちは。Link-Uの町屋敷です。

今回は、テキストデータを解析する一例として、

前回抽出した漫画のWikipediaの文章データを使って、

入力データを説明文、出力データを発表年として、入力データから出力データを推定して行きたいと思います。

また、入力データのどの要素(今回なら単語)がその回帰や分類に効力があるのかを調べる方法も紹介していきたいです。

インフォボックスから発表年のデータを取得する

まず、正解データとして使用する発表年データをインフォボックスから収集しよう。

Wikipediaのインフォボックスはjsonで要素infoboxに文字列として前回保存した。

これを適当なパーサーで処理したら完了! …とはいかない。

まず、前回保存した本文とインフォボックスjsonを結合する。

結合方法は両方のファイルにtitle情報が含まれているので紐づけするだけ。

def LoadTextJsonGenerator(files):
    for jf in files:
        with open(jf) as f:
            json_data = pd.read_json(f, lines=True)
            yield json_data.to_dict()
            
def JoinJsonData(info_json_data):
    text_json_files = sorted(glob(TEXT_JSON_DIR + '/*/*'))
    
    ltjg = LoadTextJsonGenerator(text_json_files)

    text_data = ltjg.__next__()
    for i, ijd in enumerate(info_json_data):    
        while True:
            if not ijd['title']:
                ijd['title'] = 'NULL'
            id = int(ijd['id'])
            if id in text_data['id'].values():
                data_index = list(text_data['id'].values()).index(id)
                info_json_data[i]['text'] = text_data['text'][data_index]
                if ijd['title'] == 'NULL':
                    info_json_data[i]['title'] = text_data['title'][data_index]
                break
            else:
                #print((i,id)) #確認用
                text_data = ltjg.__next__() #StopIterationしたらどっかバグってる

def Preprocess():
    with open(INFOBOX_JSON_DIR + '/' + INFOBOX_FILE_NAME) as infof:
        info_json_data = json.load(infof)
        
    JoinJsonData(info_json_data)
    with open(WRITE_JSON_DIR  + '/' + WRITE_JSON_FILE_NAME, 'w', encoding='utf_8') as jw: 
        json.dump(info_json_data, jw, ensure_ascii=False)

どうやらインフォボックスの要素名は微妙な表記ブレや要素名だけあってデータが入っていないことが割とよくあるので、

表記ブレに対応しつつ、もったいないけどデータが入っていない漫画のデータを捨てる必要がある。

このへんは正規表現で頑張ったらなんとかなる。今回は開始、発表期間、連載期間、発表号に表記ブレしていた。(ソース全文は一番下)

infobox = j['infobox'][0] #複数ある場合でも最初のもののみを使う
publication_year = re.search('| *[開始|発表期間|連載期間|発表号].*?([1|2][8|9|0]dd)[年|.]', infobox, re.MULTILINE | re.DOTALL) #@UndefinedVariable
if not publication_year:
    labels[i] = -1
    vain_count += 1
else:
    labels[i] = publication_year[1]

これで、labelsに年数が入った。データが入っていないときlabelsに-1を入れっておくことで後で対応する文章を消すことができる。

なかなかきれいなポアソン分布。

まあ一定時間で区切ったカウントデータだしね……

テキストを学習機に入れられる形に変形する

次に入力データを加工する

labelsの使えるデータの数を数えると5000ちょいだった。

この数でニューラルネットにに直接ぶち込んでLSTMとかを使って解析するのは厳しいので、BoWに加工する。

具体的には、文章を単語単位に分離して、そのうちの名詞、形容詞、動詞に番号を付けて、文章中のその各単語の出現回数を数える。

with MeCab() as mecab:

文章を単語単位に分離する手法は形態素解析と言われるが、ライブラリで簡単にできる。有名なのはChasenとかMecabだが、今回はMecabとpythonでMecabを使えるようなするNatto-pyを使う。インストール方法は調べたら山程出てくるので省略。

            words = []
            text = j['text']
            for mp in mecab.parse(text, as_nodes=True):
                if not mp.is_eos():
                    feature_splits = mp.feature.split(',')
                    if feature_splits[0] in ['名詞', '動詞', '形容詞']:
                        if feature_splits[1] in ['数']:
                            continue
                        elif feature_splits[2] in ['人名']:
                            continue
                        elif feature_splits[6] in ['*']:
                            continue
                        words.append(feature_splits[6])

if feature_splits[0] in [‘名詞’, ‘動詞’, ‘形容詞’]:の行で品詞を絞って、

その後、出現する単語のうち3とか2001などの数は、答えを書いている可能性があるので除去、

また、人名もその人が活躍する時期はある程度偏ってるはずなので、ほぼ答えになるじゃんと言う事で削除した。

def MakeDict(all_words):
    dictionary = corpora.Dictionary(all_words)
    print(dictionary.token2id)
    for no_below in [5,20,40]:
        for no_above in [0.1,0.3,0.5]:
            dictionary.filter_extremes(no_below=no_below, no_above=no_above)
            dictionary.save_as_text('filtered_dic_below{0}_above{1}.txt'.format(no_below, no_above))

単語への番号付けは、専用の辞書を作って行う。

これもgensimというライブラリがある。Mecabと同様これも情報は大量にあるので省略(例えばここここ

def MakeDict(all_words):
    dictionary = corpora.Dictionary(all_words)
    print(dictionary.token2id)
    for no_below in [5,20,40]:
        for no_above in [0.1,0.3,0.5]:
            dictionary.filter_extremes(no_below=no_below, no_above=no_above)
            dictionary.save_as_text('filtered_dic_below{0}_above{1}.txt'.format(no_below, no_above))

辞書を作る部分がここで、no_belowやno_aboveで作り分けたが今回のデータセットでは違いはほぼなかったので、no_below=5, no_above=0.1を使うことにする。

def MakeFeatures(make_dict = False, dict_param = [5, 0.1]):
    
    all_words = joblib.load('{0}/all_wordss.pkl'.format(WRITE_JOBLIB_DIR))
    labels = joblib.load('{0}/publication_years.pkl'.format(WRITE_JOBLIB_DIR))
    if (make_dict):
        MakeDict(all_words)
    
    dictionary = corpora.Dictionary.load_from_text('filtered_dic_below{0}_above{1}.txt'.format(dict_param[0], dict_param[1]))
    
    dl = len(dictionary)
    features = []
    for w, l in zip(all_words, labels):
        tmp = dictionary.doc2bow(w)
        dense = list(matutils.corpus2dense([tmp], num_terms=len(dictionary)).T[0])
        if not l == -1:
            features.append(dense)
    features = np.array(features)
    features = np.reshape(features, (-1, dl))
    labels = [int(v) for v in labels if v != -1]
    joblib.dump(features, '{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    joblib.dump(labels, '{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))

ここで、先程作った辞書を使ってBoWに変換し、labelが-1のデータを削除する

ここでTFIDFを使ってもいいが、今回はパス。これでデータセットの成形は完了。

集計した単語のうち数が多かったTop50をおいておく。

データの学習

本来はTPOTとかを駆使していろいろな学習機で調査するべきだが、時間がないので決定木グラディアントブースティングマシーン(以下、GBM)を用いた回帰と比較対象用に線形回帰を行う。

グラディアントブースティングを使用できるライブラリはsklearn,XGBoost,lightgbmと大きく3つあるが、機能の多さと実行の速さを考えるとlightgbmが良い。

def TuneXgboostRgr():
    
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
    
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)
    
    params = { 
        'learning_rate' : [0.3,0.2,0.1]
    }
    
    gs = GridSearchCV(estimator = LGBMRegressor(
                            num_leaves=31,
                            learning_rate =0.1, 
                            n_estimators=1000,
                            max_depth=9,
                            objective='regression',
                            min_sum_hessian_in_leaf=1
                        ),
                        param_grid = params, 
                        cv=5)

    gs.fit(X_train, y_train)
    print('n')
    print(gs.cv_results_)
    print('n')
    print(gs.best_params_)
    print('n')
    print(gs.best_score_)

まず、ハイパーパラメーターのチューニングをする。paramsにチューニングしたい変数と値を配列で入れるとクロスバリデーションもして一番いいパラメーターを調査してくれる。

ただし、paramsに大量の変数を設定すると永遠に終わらなくなるので、1,2種類ずついれて何回もやる。詳しくはグリッドサーチで検索!

def LearnRegressor(clf_name = 'gbm'):
    from sklearn.linear_model import LinearRegression
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
    
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)
    
    if clf_name == 'gbm':
        rgr = LGBMRegressor( 
            num_leaves=1000, 
            max_depth=9, 
            learning_rate=0.06, 
            n_estimators=100, 
            objective='regression',
        )
    elif clf_name == 'linear':
        rgr = LinearRegression()
    #rgr = XGBRegressor( n_estimators = 100, learning_rate=0.1)
    
    print('Learning Start')
    time = timer()
    rgr.fit(X_train,y_train)
    time = timer() - time
    print('Learning Finishn Time: {0}'.format(time))
    y_pred = rgr.predict(X_test)
    
    joblib.dump(rgr, '{0}/{1}_regressor.pkl'.format(WRITE_JOBLIB_DIR, clf_name))
    for yt, yp in zip (y_test, y_pred):
        print((yt,yp))
    print(mean_squared_error(y_test, y_pred))

ハイパーパラメーターのチューニングが終了したら、テストデータで評価する。

結果がこちら、左が正解、中央がGBM,右が線形回帰

比べてみると線形回帰はたまに未来や19世紀などのありえない数値を予測している。

実際誤差を計算するとGBM 59.1に対して線形回帰は329.5と大差を付けている。

(  正解, GBM , linear)

(2012, 2008, 2011)
(2008, 2011, 2007)
(2004, 2001, 2002)
(2010, 2006, 2025)
(2004, 2006, 2003)
(2006, 2004, 2012)
(2017, 2012, 2041)
(2006, 2007, 2005)
(2008, 2006, 1999)
(1977, 1985, 1987)
(2003, 2006, 2011)
(2008, 2006, 1986)
(2015, 2005, 1998)
(1992, 1998, 2003)
(2017, 2008, 2012)
(2003, 2006, 2000)
(2010, 2006, 2011)
(2010, 2006, 2011)
(2004, 2003, 2002)
(2010, 2006, 1958)
(1985, 2002, 1997)
(1994, 2003, 1988)
(1990, 1996, 2004)
(1999, 2002, 2006)
(1984, 1995, 2010)
(1987, 2000, 1993)
(1962, 2001, 2000)
(1993, 1992, 2005)
(1997, 2006, 2009)
(1988, 1982, 1989)
(1972, 1989, 1959)
(1987, 1999, 1989)
(1968, 1973, 1967)
(1970, 1982, 1970)
(1991, 2000, 2000)
(1998, 2007, 1994)
(1999, 1998, 1996)
(1994, 1993, 2001)
(1968, 1970, 1872)

回帰を行う学習器を生成できました。

これで、入力データのどの要素がその回帰や分類に効力があるのかを調べることが出来ます。

線形回帰は相関係数を計算すればすぐに見れます。

実はGBMのようなブースティングモデルの場合でも同様に簡単に調査できて、

そのまま重要度として保存されています。(参考)

両者の結果を見比べてみましょう。

線形回帰の方は法則がそんなに見えません、いんするとかいう謎単語も含まれてるし……

でも、機械学習で出した結果です!って言われたらこんなもんかとは思いそう。

GBMのほうは、昭和や平成、ビデオ、インターネットや現像といった明らかに時代に関係のあるものが紛れているのがわかります。

やたら社名が入っているのは何故だろう……

でもどちらの信用度が高いかは明らかでしょう。

まとめ

今回は漫画のWikipediaの文章データを使って、発表年を推定する方法を紹介しました。

学習器がどの要素(今回なら単語)を注視しているかも調べましたがある単語が強く関係あるとわかったところであまり使い道がありません。

次回では、別の方法でもっといい調査の方法を紹介します。

プログラム全文

import json
from glob import glob
from functools import reduce
import re
from timeit import default_timer as timer

import numpy as np
import pandas as pd
from natto import MeCab
import joblib
from gensim import corpora, matutils

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import GradientBoostingClassifier

#Xgboostは重いので使わない(5倍くらい違う)
#from xgboost import XGBRegressor
#from xgboost import XGBClassifier

from lightgbm import LGBMRegressor
from lightgbm import LGBMClassifier

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

TEXT_JSON_DIR = '../WikipediaComic/whole_data'
INFOBOX_JSON_DIR = '.'
INFOBOX_FILE_NAME = 'wiki_infobox_Infobox_animanga_Manga.json'
WRITE_JSON_DIR = '.'
WRITE_JSON_FILE_NAME = 'joined.json' 

WRITE_JOBLIB_DIR = '.'

#Matplotlibの日本語設定
font_path = '/usr/share/fonts/truetype/takao-gothic/TakaoPGothic.ttf'
font_prop = FontProperties(fname=font_path)
matplotlib.rcParams['font.family'] = font_prop.get_name()

def LoadTextJsonGenerator(files):
    for jf in files:
        with open(jf) as f:
            json_data = pd.read_json(f, lines=True)
            yield json_data.to_dict()
            
def JoinJsonData(info_json_data):
    text_json_files = sorted(glob(TEXT_JSON_DIR + '/*/*'))
    
    ltjg = LoadTextJsonGenerator(text_json_files)

    text_data = ltjg.__next__()
    for i, ijd in enumerate(info_json_data):    
        while True:
            if not ijd['title']:
                ijd['title'] = 'NULL'
            id = int(ijd['id'])
            if id in text_data['id'].values():
                data_index = list(text_data['id'].values()).index(id)
                info_json_data[i]['text'] = text_data['text'][data_index]
                if ijd['title'] == 'NULL':
                    info_json_data[i]['title'] = text_data['title'][data_index]
                break
            else:
                #print((i,id)) #確認用
                text_data = ltjg.__next__() #StopIterationしたらどっかバグってる

def Preprocess():
    with open(INFOBOX_JSON_DIR + '/' + INFOBOX_FILE_NAME) as infof:
        info_json_data = json.load(infof)
        
    JoinJsonData(info_json_data)
    with open(WRITE_JSON_DIR  + '/' + WRITE_JSON_FILE_NAME, 'w', encoding='utf_8') as jw: 
        json.dump(info_json_data, jw, ensure_ascii=False)
    
def ExtractWords():
    with open(WRITE_JSON_DIR  + '/' + WRITE_JSON_FILE_NAME, 'r', encoding='utf_8') as jw: 
        json_data = json.load(jw)
    
    all_words = [[0]] * len(json_data) 
    with MeCab() as mecab:
        for i, j in enumerate(json_data):
            if i % 100 == 0:
                print(i)
            words = []
            text = j['text']
            for mp in mecab.parse(text, as_nodes=True):
                if not mp.is_eos():
                    feature_splits = mp.feature.split(',')
                    if feature_splits[0] in ['名詞', '動詞', '形容詞']:
                        if feature_splits[1] in ['数']:
                            continue
                        elif feature_splits[2] in ['人名']:
                            continue
                        elif feature_splits[6] in ['*']:
                            continue
                        words.append(feature_splits[6])
            all_words[i] = words
            
    joblib.dump(all_words, '{0}/all_wordss.pkl'.format(WRITE_JOBLIB_DIR), compress = True)

def ExtractLabelFromInfobox():
    with open(WRITE_JSON_DIR  + '/' + WRITE_JSON_FILE_NAME, 'r', encoding='utf_8') as jw: 
        json_data = json.load(jw)    
        labels = [0] * len(json_data)
        vain_count = 0
        for i, j in enumerate(json_data):
            if i % 100 == 0:
                print(i)
            infobox = j['infobox'][0] #複数ある場合でも最初のもののみを使う
            publication_year = re.search('| *[開始|発表期間|連載期間|発表号].*?([1|2][8|9|0]dd)[年|.]', infobox, re.MULTILINE | re.DOTALL) #@UndefinedVariable
            if not publication_year:
                labels[i] = -1
                vain_count += 1
            else:
                labels[i] = publication_year[1]
        joblib.dump(labels, '{0}/publication_years.pkl'.format(WRITE_JOBLIB_DIR), compress = True)
    print(vain_count)
     
def MakeDict(all_words):
    dictionary = corpora.Dictionary(all_words)
    print(dictionary.token2id)
    for no_below in [5,20,40]:
        for no_above in [0.1,0.3,0.5]:
            dictionary.filter_extremes(no_below=no_below, no_above=no_above)
            dictionary.save_as_text('filtered_dic_below{0}_above{1}.txt'.format(no_below, no_above))
            
def MakeFeatures(make_dict = False, dict_param = [5, 0.1]):
    
    all_words = joblib.load('{0}/all_wordss.pkl'.format(WRITE_JOBLIB_DIR))
    labels = joblib.load('{0}/publication_years.pkl'.format(WRITE_JOBLIB_DIR))
    if (make_dict):
        MakeDict(all_words)
    
    dictionary = corpora.Dictionary.load_from_text('filtered_dic_below{0}_above{1}.txt'.format(dict_param[0], dict_param[1]))
    
    dl = len(dictionary)
    features = []
    for w, l in zip(all_words, labels):
        tmp = dictionary.doc2bow(w)
        dense = list(matutils.corpus2dense([tmp], num_terms=len(dictionary)).T[0])
        if not l == -1:
            features.append(dense)
    features = np.array(features)
    features = np.reshape(features, (-1, dl))
    labels = [int(v) for v in labels if v != -1]
    joblib.dump(features, '{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    joblib.dump(labels, '{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
            


def LearnRegressor(clf_name = 'gbm'):
    from sklearn.linear_model import LinearRegression
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
    
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)
    
    if clf_name == 'gbm':
        rgr = LGBMRegressor( 
            num_leaves=1000, 
            max_depth=9, 
            learning_rate=0.06, 
            n_estimators=100, 
            objective='regression',
        )
    elif clf_name == 'linear':
        rgr = LinearRegression()
    #rgr = XGBRegressor( n_estimators = 100, learning_rate=0.1)
    
    print('Learning Start')
    time = timer()
    rgr.fit(X_train,y_train)
    time = timer() - time
    print('Learning Finishn Time: {0}'.format(time))
    y_pred = rgr.predict(X_test)
    
    joblib.dump(rgr, '{0}/{1}_regressor.pkl'.format(WRITE_JOBLIB_DIR, clf_name))
    for yt, yp in zip (y_test, y_pred):
        print((yt,yp))
    print(mean_squared_error(y_test, y_pred))

def TuneXgboostRgr():
    
    features = joblib.load('{0}/features.pkl'.format(WRITE_JOBLIB_DIR))
    labels  = joblib.load('{0}/labels.pkl'.format(WRITE_JOBLIB_DIR))
    
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.1, random_state=1234)
    
    params = { 
        'learning_rate' : [0.3,0.2,0.1]
    }
    
    gs = GridSearchCV(estimator = LGBMRegressor(
                            num_leaves=31,
                            learning_rate =0.1, 
                            n_estimators=1000,
                            max_depth=9,
                            objective='regression',
                            min_sum_hessian_in_leaf=1
                        ),
                        param_grid = params, 
                        cv=5)

    gs.fit(X_train, y_train)
    print('n')
    print(gs.cv_results_)
    print('n')
    print(gs.best_params_)
    print('n')
    print(gs.best_score_)

if __name__ == '__main__':
    #Preprocess()
    #ExtractWords()
    #ExtractLabelFromInfobox()
    #MakeFeatures()
    #LearnRegressor()
    InspectRegressor()
    #TuneXgboostRgr()
    
    
    
    
    

参考サイト

https://qiita.com/conta_/items/4b031a44acceb137ec73https://yubais.net/doc/matplotlib/bar.html

https://qiita.com/buruzaemon/items/975027cea6371b2c5ec3https://qiita.com/hoto17296/items/e1f80fef8536a0e5e7dbhttps://qiita.com/yasunori/items/31a23eb259482e4824e2

Wikipediaのdumpからinfoboxの内容や文章を取ってくる方法

こんにちは!

Link-Uの町屋敷です。

今回はWikipediaの本文を収集する方法と特定のInfoboxを収集する方法を書いていきます。

Wikipediaから文章を取ってくる

Wikipediaの文章を取ってくる方法は主に以下の2つです。

  1. MediaWikiのAPIを使う
  2. 同じくMediaWikiが提供しているXML形式のダンプファイルを使う。

APIを使う方法のほうが簡単ですが、Wikipediaはクロールが禁止されているので、データセットの作成には方法2を使わざるを得ません。チャットボットとかなら1で良いでしょう。

Wikipedia Extractorを使う

今回はWikipediaExtracctorを使って本文を取得します。

まず、最新のMediaWikiが提供しているXML形式のダンプファイルの中から、

jawiki-latest-pages-articles.xml.bz2

をダウンロード

次に適当なフォルダに移動して

git clone git@github.com:attardi/wikiextractor.git
cd wikiextractor
python python setup.py install

でwikiextractorをインストール出来ます。

自分はminicondaの仮想環境に入ってからやりました。

早速本文を抽出してみましょう。

WikiExtractor.py --json --keep_tables -s --lists -o ../whole_data/ ~/Downloads/jawiki-latest-pages-articles.xml

コマンド説明
--json        : 出力をjsonにする
--keep_tables : 文章内の表を出力するようにする
-s            : セクション情報も出力するようにする
--list        : 文章内のリストを出力するようにする
-o            : 出力ファルダ名

最後のパスがさっきダウンロードしたxmlの入力ファイルです。

このフォイルはダウンロード時bz2という謎の形式で圧縮されています。

wikiextractorで使う分には圧縮されたままで良いのですが、後で中身を見る必要が出てくるので先に解凍しちゃいましょう。linuxなら

bzip2 -d jawiki-latest-pages-articles.xml.bz2 

で解凍できます。

Windowsの場合はLhaplusを使うと良いでしょう。(なぜか7zipだと失敗した)

少し話がそれましたが、先程のスクリプトを実行するとわらわらとファイルが生成されて、

中を見ると文章が取得できているのでめでたしめでたし。

と、思いきやよく見るInfoboxが取得できていない。

Infoboxとはなんぞやと言う事ですが、要はWikipedia見てると右側にわりといる情報が書かれた表です。

これ機械学習だと正解ラベルになるようなかなり重要なこと書かれてるのになんで無いの?

最初–keep_tablesのオプション使っていなくて絶対コレじゃんと思ってやってみたが増えたのは文章の途中に出てくる表だけだった。残念。

XMLから特定のInfoboxの情報を取得する

Infoboxの情報がどうしてもほしいのでネットで色々検索したが、出てくるのはAPIを使ったものばかり。

WikipediaのXMLをJSONに変換するツールなども試してみたがうまく動かず、

そんなこんなで1日くらいあがいても良いライブラリは見つからなかった。

こうなったら自分で作るしか無い

というわけで、XMLから特定のInfoboxの情報を取得してjsonで保存するスクリプトをPython3で作成した。

通常の文章はさっきほど抽出したものがあるので、必要なInfoboxが存在する項目だけタイトルとタイトルIDとセットで保存して、使うときに結合させるようにする。

どうせやるなら一緒に本文も保存したら良いじゃないかってなりそうだが、xml本文中にはいらないものが大量に含まれていて、消す作業が面倒だったので保留。

幸いあがいてる途中に参考になるサイトを見つけたので、大体はこれに沿ってやります。

プログラム全文はページの一番下。特に特別なパッケージは必要ない

大事なところだけ解説すると、

PythonでXMLからタグを取ってくる方法は、普通は

import xml.etree.ElementTree as ET
tree = ET.parse('country_data.xml')
root = tree.getroot()

こんな感じだが、今回はxmlファイルが馬鹿でかいのでこれを行うとメモリが死ぬ可能性がある。

なので、イテレータを用いて処理する。

for event, elem in etree.iterparse(pathWikiXML, events=('start', 'end')): 

処理は上のタグから一つずつ進行していく、例えば

<a>
    <b>0</b>
    <c>
      <d>70334050</d>
    </c>
</a>

このようなタグが来た時は、elem.tagにはa->b->b->c->d->d->c->aの順番でデータが入り、

開きタグなのか閉じタグなのかは,eventを見て判断する。

elem.textにそのタグに囲まれた部分の文章がひるようだ、

ここで、実際のxmlがどんな形式ななっているか確認しましょう。

less ~/Downloads/jawiki-latest-pages-articles.xml

巨大なファイルなので通常のテキストエディタでは開きません。

どうやら<page>~</page>でwikipediaの各ページが表現されていて、

<title>にタイトル

<id>にページ番号

<text>にメインの文章が書かれているようだ。

ただ、Infoboxは少々面倒で、特定のタグに囲われているわけでもなく、<text></text>の中に

{{東京都の特別区
|画像 = [[File:Sensoji 2012.JPG|200px]]
|画像の説明 = [[浅草寺]]境内
|区旗 = [[File:Flag of Taito, Tokyo.svg|100px]]
|区旗の説明 = 台東[[市町村旗|区旗]]&lt;div style=&quot;font-size:smaller&quot;&gt;[[1965年]][[6月4日]]制定
|区章 = [[File:東京都台東区区章.svg|75px]]
|区章の説明 = 台東[[市町村章|区章]]&lt;div style=&quot;font-size:smaller&quot;&gt;[[1951年]][[4月18日]]制定&lt;ref&gt;「東京都台東区紋章制定について」昭和26年4月18日台東区告示第47号&lt;/ref&gt;
|自治体名 = 台東区
|コード = 13106-7
|隣接自治体 = [[千代田区]]、[[中央区 (東京都)|中央区]]、[[文京区]]、[[墨田区]]、[[荒川区]]
|木 = [[サクラ]]
|花 = [[アサガオ]]
|郵便番号 = 110-8615
|所在地 = 台東区[[東上野]]四丁目5番6号&lt;br /&gt;&lt;small&gt;{{ウィキ座標度分秒|35|42|45.4|N|139|46|47.9|E|region:JP-13_type:adm3rd|display=inline,title}}&lt;/small&gt;&lt;br /&gt;[[File:Taito Ward Office.JPG|250px|台東区役所庁舎(東上野四丁目)]]
|外部リンク = [http://www.city.taito.lg.jp/ 台東区]
|位置画像 = {{基礎自治体位置図|13|106}}
|特記事項 =}}

突然現れるだけで文章からの抽出が必要。

どうやら<text>中で{{}}で表されているものは複数あるらしく、単純に{{}}に囲われた部分を抽出してもうまく行かない。

そこでInfoboxの名前を使う。名前は{{のすぐ後に書いているやつで、lessでそのInfoboxが使われている項目を検索して直接確認するか、Wikipediaの基本情報テンプレートに書いてあるTemplate:Infobox ~ の項目を確認する。

                    elem.text = re.sub('{{[F|f]lagicon.*?}}', '', elem.text)
                    infobox = re.findall('{{{0}n.*?|.*?}}'.format(INFOBOX_SEARCH_WORD), elem.text, re.MULTILINE | re.DOTALL)

ここがその抽出部分で、INFOBOX_SEARCH_WORDにInfoboxの名前を入れるとその名前に対応するInfoboxを抽出できる。ただし、{{}}が入れ子になっているとバグるので、re.subで事前にいらないものを消している。

reの代わりにregexを使うと入れ子でも処理できるようだが、今回選んだInfoboxではこれで大丈夫だったので保留。

試しに上のWikipedia画面のキャプチャで囲んでいた漫画のInfobox(Infobox名 animanga/Manga)で試してみた結果がこちら。

ちゃんと取得できていることが確認できた。

参考サイト

https://www.heatonresearch.com/2017/03/03/python-basic-wikipedia-parsing.html

プログラム全文

import xml.etree.ElementTree as etree
import time
import os
import json
import re
import collections as cl

PATH_WIKI_XML = '/home/machiyahiki/Downloads/'
FILENAME_WIKI = 'jawiki-latest-pages-articles.xml'
JSON_SAVE_DIR = '.'
INFOBOX_SEARCH_WORD = 'Infobox animanga/Manga'

ENCODING = "utf-8"

def strip_tag_name(t):
    idx = t.rfind("}")
    if idx != -1:
        t = t[idx + 1:]
    return t

pathWikiXML = os.path.join(PATH_WIKI_XML, FILENAME_WIKI)

totalCount = 0
articleCount = 0
redirectCount = 0
templateCount = 0
title = None
start_time = time.time()
dict_array = []

with open('{0}//wiki_infobox_{1}.json'.format(JSON_SAVE_DIR, re.sub('[ |//]', '_',  INFOBOX_SEARCH_WORD)),'w', encoding='utf_8') as jf:
    for event, elem in etree.iterparse(pathWikiXML, events=('start', 'end')):
        tname = strip_tag_name(elem.tag)
    
        if event == 'start':
            if tname == 'page':
                inrevision = False
                findinfobox = False
                data_dict = cl.OrderedDict()
            elif tname == 'revision':
                # Do not pick up on revision id's
                inrevision = True
            if tname == 'title':
                data_dict['title'] = elem.text
            elif tname == 'id' and not inrevision:
                data_dict['id'] = elem.text
        else:
            if tname == 'text':
                if elem.text:
                    elem.text = re.sub('{{[U|u]nicode.*?}}', '', elem.text)
                    elem.text = re.sub('{{[F|f]lagicon.*?}}', '', elem.text)
                    infobox = re.findall('{{{0}n.*?|.*?}}'.format(INFOBOX_SEARCH_WORD), elem.text, re.MULTILINE | re.DOTALL)
                    if infobox:
                        findinfobox = True
                        data_dict['infobox'] = infobox
            if tname == 'page':
                if findinfobox:
                    if data_dict['title']: #タイトル名に{{Unicode}}があるとnullが入る、後で修可
                        if 'プロジェクト:' in data_dict['title']: #プロジェクト:から始まる項目は無視
                            continue
                        if 'Template:Infobox' in data_dict['title']: #Template:Infoboxから始まる項目も無視
                            continue
                    dict_array.append(data_dict)
                    
                
        if len(dict_array) > 1 and (len(dict_array) % 10000) == 0:
            print("{:,}".format(len(dict_array)))

        
        elem.clear()
    json.dump(dict_array, jf, ensure_ascii=False)
    
    
elapsed_time = time.time() - start_time

print("Total pages: {:,}".format(len(dict_array)))
print("Elapsed Time: {:,}".format(elapsed_time))   
                
    

Pythonで次元圧縮する方法

こんにちは、Link-Uの町屋敷です。

今回は次元圧縮について書いていこうと思います。

データの次元数が多いとどうなるのか

次元の呪いという単語を機械学習では度々目にします。
入力するデータの次元数が多いとモデルに対して与えられる点が相対的に少なくなっていろいろ不都合が出るとか、単純に計算量が多くなってやばいといったもので、
計算が終わらないから次元圧縮するという流れになるんですが。
そもそも使用する頻度の高いSVMなどの学習機で実際どのくらい終わらなくなるのか計測したことがなかったので、
次元圧縮の話の前にまずやってみました。

今回使用するデータは生成データで2クラス分類問題を行います。

DIMENTION_MAX = 2000
SEED = 12345
np.random.seed(SEED)

n_features = 1000
label = []

#ラベルの生成
for i in range(int(n_features/2)):
    label.append(0)
    label.append(1)
    
label = np.array(label)

dim = 100
#次元数がMAXを超えるまでループ
while dim < DIMENTION_MAX:
    feature = []
    x_axis.append("{0}".format(dim))
    
    for i in range(dim):
        #データ生成用の変数を作る
        mu_a = np.random.normal(0,1)
        omega_a = np.random.rand()
        
        mu_b = np.random.normal(0,1)
        omega_b = np.random.rand()
        
        for j in range(int(n_features/2)):
            #上がラベル0、下がラベル1
            feature.append(np.random.normal(mu_a,omega_a) + np.random.normal(0,10*dim/DIMENTION_MAX))
            feature.append(np.random.normal(mu_b,omega_b) + np.random.normal(0,10*dim/DIMENTION_MAX))
            
    #特徴量を要素数*次元数の形に変換        
    feature = np.reshape(feature, (dim, n_features))
    feature = feature.T

具体的には上コードで正規分布を使って作ったデータです。

最初と最後の特徴量だけ取ってきてラベルごとに色を変えてプロットするとこんな感じ。

実データを使うとデータごとに違った結果にまずなりますので今回の結果はあくまで参考です。
次元数は100から始めて2000まで100ずつ増加させてそれぞれ1000個ずつデータを作ります。

次に各次元で生成した特徴量を学習データとテストデータに分割するんですが、

sklearn.model_selection.train_test_split

をそのまま使うと分割したデータのラベルの数が偏ってしまうので、今回はここから関数を拝借してきて使ってます。

それをSVM,ロジスティック回帰、ランダムフォレスト、ニューラルネットで学習して、次元数ごとに学習にかかった時間と学習器の良さを計るF1値をノートパソコンとサーバーで計測しました。

ただし、scikit-learnがGPUに対応していないので、ニューラルネットのみGPUを使っています。

結果は以下の通りで左側のy軸が計算にかかった時間、右側のy軸がF1の値でx軸は次元数です。

このように次元数が増えると時間がかかるだけでなく、精度も落ちてしまうアルゴリズムが出てきてしまいます。

特にSVMは学習に使うデータ量より多い次元を分類しようとすると精度が絶望的になります。

このような問題を回避するためには、データ量を増やす他に、次元を削減したり、データのすべての次元を使うのではなく有用だろうと思われるものを選んで学習に使う必要があります。

次元圧縮

今回は次元を圧縮する方法として、主成分分析(PCA)線形判別分析(LDA)を使います。

タスクが簡単なので線形な変換をするやつだけ。

文章トピック分類に用いるLDAではないので注意。

PCAには教師データが必要ありませんが、LDAには必要です。

LDAは教師データを使って最もよくクラスタを分類するように新しい軸を検索します。

一方でPCAには教師データが必要ありません。

PCAとLDAの違いはここがわかりやすい(英語)

PCA,LDAともにskikit-learnに実装されているので簡単に行うことが出来ます。

from sklearn.decomposition import PCA
    
    dcp = PCA(n_components=10) #n_componentsで削減する次元数を指定
    dcp_data = dcp.fit_transform(data)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    dcp = LinearDiscriminantAnalysis(n_components=10) #n_componentsで削減する次元数を指定
    dcp_data = dcp.fit_transform(data, label) #教師データとしてlabelを渡す

上に載せてあるデータにPCAをかけて2次元に圧縮すると結果はこんな感じになります。

LDAをかけて1次元に圧縮する(カラスが2つなので1次元になる)とこんな感じ。

PCAを使って圧縮時間も含めて処理時間とF1値を計算した結果がこちら。

比較用に圧縮してないときの結果も載せた。

F1は1が最高ですべてのテストデータを正解したことを示す。

全体的に処理速度が早くなり、F1の値が向上している。

ニューラルネットだけ時間が伸び気味なのは何故だろうか…

まとめ

今回は次元の圧縮で処理が早くなり、精度も向上することを示したが、

実データじゃない簡単タスクだからかもしれない。

isomapとかの非線形の次元圧縮方法や、t-sne、オートエンコーダも紹介していないので、

もっと難しいデータでいつか紹介したい。

使用したソースコード

#Python3系じゃないとバグります

import time
from matplotlib import pyplot as plt
import numpy as np
from sklearn.decomposition import PCA, FastICA
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.mixture import GaussianMixture
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.manifold import TSNE

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten

import joblib

#Quoted from https://stackoverflow.com/questions/35472712/how-to-split-data-on-balanced-training-set-and-test-set-on-sklearn
def get_safe_balanced_split(target, trainSize=0.8, getTestIndexes=True, shuffle=False, seed=None):
    classes, counts = np.unique(target, return_counts=True)
    nPerClass = float(len(target))*float(trainSize)/float(len(classes))
    if nPerClass > np.min(counts):
        print("Insufficient data to produce a balanced training data split.")
        print("Classes found %s"%classes)
        print("Classes count %s"%counts)
        ts = float(trainSize*np.min(counts)*len(classes)) / float(len(target))
        print("trainSize is reset from %s to %s"%(trainSize, ts))
        trainSize = ts
        nPerClass = float(len(target))*float(trainSize)/float(len(classes))
    # get number of classes
    nPerClass = int(nPerClass)
    print("Data splitting on %i classes and returning %i per class"%(len(classes),nPerClass ))
    # get indexes
    trainIndexes = []
    for c in classes:
        if seed is not None:
            np.random.seed(seed)
        cIdxs = np.where(target==c)[0]
        cIdxs = np.random.choice(cIdxs, nPerClass, replace=False)
        trainIndexes.extend(cIdxs)
    # get test indexes
    testIndexes = None
    if getTestIndexes:
        testIndexes = list(set(range(len(target))) - set(trainIndexes))
    # shuffle
    if shuffle:
        np.random.shuffle(trainIndexes)
        if testIndexes is not None:
            np.random.shuffle(testIndexes)
    # return indexes
    return trainIndexes, testIndexes
#Quote end

DIMENTION_MAX = 2000
SEED = 12345
np.random.seed(SEED)

n_features = 1000
label = []

#ラベルの生成
for i in range(int(n_features/2)):
    label.append(0)
    label.append(1)
durations = []
f1_values = []
label = np.array(label)
for clf_name in ['SVM', 'LogisticRegression', 'RandomForest', 'NeuralNet']:

    x_axis = []
    for decomp_name in ['LDA']:#['LDA', 'ICA']:
        dim = 1000
        #次元数がMAXを超えるまでループ
        while dim < DIMENTION_MAX:
            feature = []
            x_axis.append("{0}".format(dim))
            
            for i in range(dim):
                #データ生成用の変数を作る
                mu_a = np.random.normal(0,1)
                omega_a = np.random.rand()
                
                mu_b = np.random.normal(0,1)
                omega_b = np.random.rand()
                
                for j in range(int(n_features/2)):
                    #上がラベル0、下がラベル1
                    feature.append(np.random.normal(mu_a,omega_a) + np.random.normal(0,10*dim/DIMENTION_MAX))
                    feature.append(np.random.normal(mu_b,omega_b) + np.random.normal(0,10*dim/DIMENTION_MAX))
            
                    
            #特徴量を要素数*次元数の形に変換        
            feature = np.reshape(feature, (dim, n_features))
            feature = feature.T

            
            #バランス良く訓練データとテストデータに分割
            train_index, test_index = get_safe_balanced_split(label, trainSize=0.8, getTestIndexes=True, shuffle=True, seed=SEED)
           
            train_feature, test_feature = feature[train_index], feature[test_index]
            train_label, test_label = label[train_index], label[test_index]
            
            #学習器の選択
            if clf_name == 'SVM':
                clf = SVC(C = 1, gamma = 0.01)
            elif clf_name == 'LogisticRegression':
                clf = LogisticRegression(C = 1e5)
            elif clf_name == 'RandomForest':
                clf = ExtraTreesClassifier(n_estimators=1000,random_state=0)
            
            start = time.time()
            
            #次元圧縮法の選択
            if decomp_name == 'PCA':
                dec = PCA(n_components = 2)
                train_feature = dec.fit_transform(train_feature)    
                test_feature  = dec.transform(test_feature)
            elif decomp_name == 'LDA':
                dec = LinearDiscriminantAnalysis(n_components=1)
                train_feature = dec.fit_transform(train_feature, train_label)
                test_feature  = dec.transform(test_feature)
            elif decomp_name == 'ICA':
                dec = FastICA(n_components = 2)
                train_feature = dec.fit_transform(train_feature)
                test_feature  = dec.transform(test_feature)
            elif decomp_name == 'TSNE':
                dec = TSNE(n_components=2, perplexity=30)
                train_feature = dec.fit_transform(feature)
            
            if clf_name == 'NeuralNet':
                clf = Sequential()
                clf.add(Dense(64, activation='relu', input_shape=(len(train_feature[0]),)))
                clf.add(Dense(1, activation='sigmoid'))
                
                clf.compile(loss='binary_crossentropy',
                            optimizer=keras.optimizers.Adadelta(),
                            metrics=['accuracy'])
                
                clf.fit(train_feature, train_label,
                          batch_size=128,
                          epochs=100,
                          verbose=0)
                predict = np.round(clf.predict(test_feature))
            else:
                clf.fit(train_feature, train_label)
                predict = clf.predict(test_feature)
            
            
            
            duration = time.time() - start
            
            print("Dim {0}, Time: {1}".format(dim, duration))
            print(confusion_matrix(test_label, predict, [0,1]))
            
            f1_values.append(f1_score(test_label, predict, [0,1]))
            durations.append(duration)
            dim += 100
            
        #データの保存、保存したデータはjoblib.loadでロードできる 
        joblib.dump(durations, '{0}_{1}_Note_durations.pkl'.format(clf_name, decomp_name))
        joblib.dump(f1_values, '{0}_{1}_Note_f1values.pkl'.format(clf_name, decomp_name))

C#で強化学習 その1 -Q-learning(Q学習)で簡単なゲームAIを作ってみる-

こんにちはLInk-Uの町屋敷です。

今回は強化学習をやっていきたいと思います。

主にQ-learningの具体的な実装の方法を書いて、Q-learning自体の証明とかには触れません。

強化学習は今までやってきたニューラルネットやSVMなどの学習方法と毛色が異なります。

何をやるかをざっくりいうと
ある問題を解きたいときにある状況になったときにこういうことをしたらこうなったという経験を蓄積して、
その経験を元に次に同じ状況になったとき最適な行動を選択するようにAIを訓練します。

Q-learningを行う準備

強化学習を行うアルゴリズムはたくさんあるのですが、今回はQ-learningという手法をC#で一から実装していきます。
いつものPythonだと一から実装すると重くてやってられなくなるかもしれないから回避。
C#には今回使わなかったけどUnityもあるしね。

早速実装して行きましょう。
今回のタスクはRPGでよくある1対1の戦闘で相手モンスターを倒すことです。
プレイヤーは状況に応じて攻撃や回復を行います。

強化学習ではプレイヤーに何回も戦闘を行ってその結果によってAIを賢くしていきます。
なので、Q-learningどうこうのまえに自動で戦闘を行うプログラムを先に書く必要があります。

まずプレイヤーや相手モンスターを扱う親クラスCharacterを作りましょう。

public abstract class Character
    {
        private string name;
        private int id;
        private int maxHp;
        private int maxMp;
        private int hp;
        private int mp;
        private Dictionary<int, Action> actions;
        private Dictionary<string, double> weeknesses;

        public Character()
        {
            this.Init();
        }

        private void Init()
        {
            this.name = this.SetName();
            this.id = this.SetId();
            this.maxHp = this.SetMaxHp();
            this.maxMp = this.SetMaxMp();
            this.hp = this.maxHp;
            this.mp = this.maxMp;
            this.actions = this.SetActions();
            this.weeknesses = this.SetWeekness();
        }

        private string SetName()
        {
            return this.GetType().Name;
        }

        public abstract int SetId();
        public abstract int SetMaxHp();
        public abstract int SetMaxMp();

        public void SetHp(int hp)
        {
            this.hp = hp;
        }

        public void SetMp(int mp)
        {
            this.mp = mp;
        }

        public abstract Dictionary<int, Action> SetActions();
        public abstract Dictionary<string, double> SetWeekness();

        public string GetName()
        {
            return this.name;
        }

        public int GetId()
        {
            return this.id;
        }

        public int GetHp()
        {
            return this.hp;
        }

        public int GetMp()
        {
            return this.mp;
        }


        public int GetMaxHp()
        {
            return this.maxHp;
        }

        public int GetMaxMp()
        {
            return this.maxMp;
        }

        public Dictionary<int, Action> GetActions()
        {
            return this.actions;
        }

        public Dictionary<string, double> GetWeekness()
        {
            return this.weeknesses;
        }

        public void RestoreHp()
        {
            this.hp = this.maxHp;
        }

        public void RestoreMp()
        {
            this.mp = this.maxMp;
        }

    }

攻撃力とか防御力とかはなくて、HPとMPだけです、ダメージ量とかは技で固定で、攻撃の属性に対する相性をweeknessesで設定するモデルです。

カードゲームによくある仕組み。

次に、キャラクターごとにデータを作っていきます。

public class Player : Character
    {
        public Player() : base()
        {

        }

        public override int SetId()
        {
            return 0;
        }

        public override int SetMaxHp()
        {
            return 100;
        }

        public override int SetMaxMp()
        {
            return 100;
        }

        public override Dictionary<int, Action> SetActions()
        {
            var actions = new Dictionary<int, Action>();
            actions.Add(ActionTable.normalAttack.id, ActionTable.normalAttack);
            actions.Add(ActionTable.magicAttack.id, ActionTable.magicAttack);
            actions.Add(ActionTable.heal.id, ActionTable.heal);
            return actions;
        }

        public override Dictionary<string, double> SetWeekness()
        {
            var weeknesses = new Dictionary<string, double>();
            weeknesses.Add("physical", 1.0);
            weeknesses.Add("magic", 1.0);
            return weeknesses;
        }
    }

SetWeeknessはさっき書いた属性に対する耐性です。今回属性は「物理」と「魔法」があります。

SetActions関数でそのキャラクターが選択できる行動を定義します。

CharacterクラスとPlayerクラスのように親クラスとしてActionクラスを作り、

行動の内容はActionTableクラスに定義されています。

    public class Action
    {
        public int id;
        public int hpDamage;
        public int hpHeal;
        public int mpCost;
        public string attribute;

        public Action(int id, int hpDamage, int hpHeal, int mpCost, string attribute, ref int count)
        {
            this.id = id;
            this.hpDamage = hpDamage;
            this.hpHeal = hpHeal;
            this.mpCost = mpCost;
            this.attribute = attribute;
            ++count;
        }
    }

    static class ActionTable
    {
        public static int count;
        public static Action normalAttack;
        public static Action magicAttack;
        public static Action heal;
        public static Action strongAttack;

        static ActionTable()
        {
            count = 0;
            normalAttack = new Action(count, 10, 0, 0, "physical", ref count);
            magicAttack = new Action(count, 15, 0, 10, "magic", ref count);
            heal = new Action(count, 0, 40, 10, "magic", ref count);
            strongAttack = new Action(count, 30, 0, 0, "physical", ref count);
        }
    }

Playerクラスでプレイヤーの情報を設定しました。

同じことを敵モンスターでも行います。今回はゴブリン、ウィッチ、グリズリーの3種類の敵を作りました。

それぞれの特徴は後で書きます。

さて、これでデータは揃ったのであとは戦闘部分を書けばいいのですが、Q-learningの説明を先にしてから説明します。

Q-learningを実装する

強化学習は、ある状況になったときにこういうことをしたらこうなったという経験を蓄積して、
その経験を元に次に同じ状況になったとき最適な行動を選択するようにAIを訓練するものでした。

よってまずどんな状況を考慮するかを決めてあげなくてはなりません。

RPGの一対一の戦闘で考えられうる状況のうち今回は以下の4つを考えます。

  1. 何と戦っているか
  2. 自分の残りHPはどのくらいか
  3. 自分の残りMPはどのくらいか
  4. 戦闘に勝利、敗北した

1は今回ゴブリン、ウィッチ、グリズリーの3パターンです。

残りHPや残りMPは単なる数字なのでその値一つ一つを異なる状況と考えてしまうと

連続値を扱えないQ-learningだと状況の数がかなり増えて計算に時間がかかってしまうので、ざっくり離散化します。

HPでは、残りHPが半分以上、残りHPが1/4以上、残りHPが1/4未満の3パターン。

MPでは、残りMPが半分以上ある、残りMPが半分未満だが0ではない、残りMP0の3パターンです。

1,2,3が戦闘中に考えられる状況です。

1,2,3それぞれ3パターンずつあって4が2パターンあるので今回考えられる状況は3*3*3+2=29通り存在します。

それぞれの状況でプレイヤーは選択できる行動を行います。つまりMPが切れていない状況では、

Playerクラスの関数SetActionsで追加した、通常攻撃、魔法攻撃、HP回復のうちどれかを実行することになります。

通常攻撃は相手に10ダメージ、魔法攻撃はMP10を使って相手に15ダメージ、回復はMP10を使って自分のHPを40回復します。

回復だけやけに数値が大きいのは、はじめ10でやってたんですけどQ-learningの結果全く使われない産廃と化していたので上げました。

次に報酬について説明します。

強化学習では、行った行動がいい行動だったのか悪い行動だったのかを判別する基準として報酬を用います。

つまり、行動の結果に点数を付けてあげます。

今回はRPGの戦闘なので、敵を倒したら1000点、負けたら-1000点といった超単純なものです。

HP,Mpがたくさんあって敵がゴブリンのときに通常攻撃して敵を倒せたらその組み合わせにプラスの評価がされるというわけです。

しかし、毎回行動を行った直後に効果が現れるとは限らないので、ある行動の評価を行うときはある程度未来までみてその間に得た報酬を参考にします。

つまり、HP,Mpがたくさんあって敵がゴブリンのときに通常攻撃を2回行ってから魔法攻撃をして敵を倒した場合、最後の魔法攻撃だけにいい評価を与えるのではなく、最初の通常攻撃にも意味があるだろうという考えです。

ある程度未来までみてその間に得た報酬のことを収益と言ってその期待値を行動価値Qといいます。

このへんは本や他のサイトのほうが詳しいので、それを見ることをあ勧めします。

Qは収益の期待値を表すので、ある状況、行動を行ったときのQの値がわかっていれば、選択できる行動の中でQの一番大きい行動を選んでおけば良いということがわかります。

そしてこのQを経験から計算するのがQ-learningです。

C#ではこんな感じ。

        public void updateQ(int situationNo, int nextSituation, int actionNo, double reward, List<int> unselectableActions = null, List<int> nextUnselectableActions = null)
        {
            int maxIndex = -1;
            double maxQ = -10000000;

            this.qValues[situationNo, actionNo] = (1 - this.alpha) * this.qValues[situationNo, actionNo]
                + this.alpha * (reward + this.gamma * serachMaxAndArgmax(nextSituation, ref maxIndex, ref maxQ, nextUnselectableActions));
        }

Qの値は状況、行動ごとに変わるので1度に更新するのは配列 this.qValues[situationNo, actionNo]の値1つです。

1回の戦闘で報酬が得られるタイミングは1回だけなので何回も何回も戦闘を行います。

1回の戦闘のことを1エピソードと呼びます。

        public override int MainProcess(int nEpisodes)
        {
            this.player = new Player();
            this.SetEneyList();

            var q = new QLearning(
                this.GetSituationSize(),
                this.GetActionSize(),
                0
            );
            int situationNo;
            int nextSituationNo;
            int e = 0;
            while (e < nEpisodes)
            {
                //this.enemy = this.SelectRandomEnemy();
                this.enemy = this.enemyList[e % 3];
                player.RestoreHp();
                player.RestoreMp();
                enemy.RestoreHp();
                enemy.RestoreMp();
                do
                {
                    situationNo = this.GetSituationNo();
                    nextSituationNo = -1;
                    var unselectable = this.GetUnselectableActions(situationNo);
                    int actionNo = q.SelectActionByEGreedy(0.05, situationNo, unselectable);

                    double reward = this.CalcActionResult(situationNo, ref nextSituationNo, actionNo);
                    var nextUnselectable = this.GetUnselectableActions(nextSituationNo);
                    q.updateQ(situationNo, nextSituationNo, actionNo, reward, unselectable, nextUnselectable);

                } while (nextSituationNo != ENEMY_DEATH && nextSituationNo != PLAYER_DEATH);

				if (e % 1000 == 0)
				{
					Console.WriteLine();
                    Console.Write(String.Format("Progress {0:f4}%", (double) 100 * e / nEpisodes));
					Console.WriteLine();
					this.PrintQValuesWithParams(q);
				}
				e++;
            }

            Console.WriteLine();
            this.PrintQValuesWithParams(q);

            return 0;
        }

while (e < nEpisodes)の中が1回の戦闘で、do-while文の中がプレイヤーと敵の攻撃の1ターンになります。

this.GetSituationNo()で現在の状況を確認し、q.SelectActionByEGreedyで状況とQをもとに行動を選択しています。

ここで、常に一番大きなQをもつ行動を選択するようにしてしまうと、たまたま良い結果が出たものを集中的に行って本当に良い行動を見落としてしまう可能性があるため小さい確率で一番大きなQをもつ行動以外を選択するようにします。

CalcActionResultで戦闘の処理と報酬の付与を行っています。

        public override double CalcActionResult(int situationNo, ref int nextSituation, int actionNo)
        {

            if (situationNo == PLAYER_DEATH)
            {
                nextSituation = PLAYER_DEATH;
                return PLAYER_DEATH_REWARD;
            }
            if (situationNo == ENEMY_DEATH)
            {
                nextSituation = ENEMY_DEATH;
                return ENEMY_DEATH_REWARD;
            }

            int playerHp = this.player.GetHp();
            int playerMp = this.player.GetMp();
			int playerMpOrg = playerMp;
            int enemyHp = this.enemy.GetHp();
            int enemyMp = this.enemy.GetMp();

            Action playerAction = this.player.GetActions()[actionNo];
            string playerAttackAttribute = playerAction.attribute;
            enemyHp -= (int)(playerAction.hpDamage * this.enemy.GetWeekness()[playerAttackAttribute]);
            playerHp += this.player.GetActions()[actionNo].hpHeal;
            if (playerHp > this.player.GetMaxHp())
                playerHp = this.player.GetMaxHp();
            playerMp -= this.player.GetActions()[actionNo].mpCost;
            if (enemyHp <= 0)
            {
                nextSituation = ENEMY_DEATH;
                return ENEMY_DEATH_REWARD;
            }
            Action enemyAction = this.SelectRandomAction(enemy);
            string enemyAttackAttribute = enemyAction.attribute;
            playerHp -= (int)(enemyAction.hpDamage * this.player.GetWeekness()[enemyAttackAttribute]);
            enemyHp += enemyAction.hpHeal;
            if (enemyHp > this.enemy.GetMaxHp())
                enemyHp = this.enemy.GetMaxHp();
            enemyMp -= enemyAction.mpCost;
            if (playerHp <= 0)
            {
                nextSituation = PLAYER_DEATH;
                return PLAYER_DEATH_REWARD;
            }
            this.player.SetHp(playerHp);
            this.player.SetMp(playerMp);
            this.enemy.SetHp(enemyHp);
            this.enemy.SetMp(enemyMp);


            nextSituation = this.GetSituationNo();

			return (playerMp - playerMpOrg) * MP_CONSUMPTION_REWARD_RATE;
        }

プレイヤーとモンスターのHP,MPを足したり引いたりしてHPが0になったら死亡判定してるだけです。

どちらかが死亡するとエピソードが終了します。

エピソードを行う数は予め決まっていてそれをすべて実行し終えると終了です。

Q-learningの実行結果

各モンスターに対して100万回戦闘しました。

普通にやったら何ヶ月かかるのかわからない処理ですがPCさんなら6分17秒でやってくれます。

対戦相手がゴブリンのときの結果がこれ。

ゴブリンは通常攻撃、魔法攻撃ともに等倍のダメージを受け、毎ターン通常攻撃しか行わないただの雑魚です。

少し見にくいが2行で1つのデータで、1行目は左から「対戦相手」、「HP状態」、[MP状態]を表す。、「HP状態」は0のとき半分以上1のとき1/4以上、2のときが1/4以下を表し、「MP状態」は0のとき半分以上1のとき0以上、2のときが0です。2行目は各行動の行動価値を表していて左から通常攻撃、魔法攻撃、回復です。

だから、上から2行目までは対戦相手がゴブリンのときでHP,MPが十分ある時は3つの行動のうち一番値の大きい魔法攻撃をするのが良いことになります。

BlogRainForest.Gobrin 2 0の行をみるとHPがかなり少なくてMPが十分ある時は回復するのが良いということがわかります。

次にウイッチの場合、ウイッチは裏で通常攻撃を効きやすく(2倍)、魔法攻撃を効きにくく(半分)しています。

結果では通常攻撃の行動価値が魔法攻撃の行動価値よりも常に大きくなっているのでウィッチ相手には通常攻撃と回復しか選択しないようなAIを作ることが出来ました。

まとめ

Q-learningを用いることで、簡単なゲームAIを作ることが出来ました。

Q-learningでは状態は離散値でしたが状況や行動を連続した値で表して計算を行う方法も存在するので、機会があったら解説しようと思います。

使用したプログラム

using System;
using System.Collections.Generic;


namespace BlogRainForest
{
    class Program
    {
        static void Main(string[] args)
		{
			var sw = new System.Diagnostics.Stopwatch();
            sw.Start();

            //var ql = new QLearning();
            var rb = new RpgBattle();
            rb.MainProcess(3000000);
			sw.Stop();
			TimeSpan ts = sw.Elapsed;
            Console.WriteLine($" {sw.ElapsedMilliseconds}ms");
            Console.Write("nEndn");
            Console.Read();
        }
    }


    public class QLearning
    {
        public double[,] qValues;
        public double alpha;
        public double gamma;

        public QLearning(int sSize, int aSize, int fillValue, double alpha = 0.01, double gamma = 0.8)
        {
            this.alpha = alpha;
            this.gamma = gamma;
            this.qValues = new double[sSize, aSize];
            for (int i = 0; i < sSize; i++)
            {
                for (int j = 0; j < aSize; j++)
                {
                    this.qValues[i, j] = fillValue;
                }
            }
        }

        public void updateQ(int situationNo, int nextSituation, int actionNo, double reward, List<int> unselectableActions = null, List<int> nextUnselectableActions = null)
        {
            int maxIndex = -1;
            double maxQ = -10000000;

            this.qValues[situationNo, actionNo] = (1 - this.alpha) * this.qValues[situationNo, actionNo]
                + this.alpha * (reward + this.gamma * serachMaxAndArgmax(nextSituation, ref maxIndex, ref maxQ, nextUnselectableActions));
        }

        public int SelectActionByGreedy(int situationNo, List<int> unselectableActions = null)
        {
            unselectableActions = unselectableActions ?? new List<int>();
            int maxIndex = -1;
            double maxQ = -10000000;
            this.serachMaxAndArgmax(situationNo, ref maxIndex, ref maxQ, unselectableActions);
            return maxIndex;
        }

        public int SelectActionByEGreedy(double epsilon, int situationNo, List<int> unselectableActions = null)
        {
            Random r = new Random();
            if (r.NextDouble() < epsilon)
            {
				int action = -1;
				do
				{
					action = r.Next(this.qValues.GetLength(1));
				} while (unselectableActions.Contains(action));
				return action;
            }
            else
            {
                return this.SelectActionByGreedy(situationNo, unselectableActions);
            }
        }

        private double serachMaxAndArgmax(int situationNo, ref int maxIndex, ref double maxQ, List<int> unselectableActions = null)
		{
            unselectableActions = unselectableActions ?? new List<int>();
            for (int j = 0; j < this.qValues.GetLength(1); j++)
            {
				
				if (unselectableActions.Contains(j))
				{
					continue;
				}
                if (this.qValues[situationNo, j] > maxQ)
                {
                    maxIndex = j;
                    maxQ = this.qValues[situationNo, j];
                }
            }
            
            return maxQ;
        }

        public void PrintQValues()
        {
            var rowCount = this.qValues.GetLength(0);
            var colCount = this.qValues.GetLength(1);
            for (int row = 0; row < rowCount; row++)
            {
                for (int col = 0; col < colCount; col++)
                    Console.Write(String.Format("{0}t", this.qValues[row, col]));
                Console.WriteLine();
            }
        }
    }


    public abstract class Task
    {
        public abstract int MainProcess(int nEpisodes);
        public abstract int GetSituationSize();
        public abstract int GetActionSize();
        public abstract List<int> GetUnselectableActions(int situation);
        public abstract double CalcActionResult(int situationNo, ref int nextSituation, int actionNo);
    }

    public class RpgBattle : Task
    {
        public const int PLAYER_DEATH = 27;
        public const int ENEMY_DEATH = 28;

        public const double PLAYER_DEATH_REWARD = -1000;
        public const double ENEMY_DEATH_REWARD = 1000;
		public const double MP_CONSUMPTION_REWARD_RATE = 0;

        public Character player;
        public Character enemy;
        public List<Character> enemyList;


        public override int MainProcess(int nEpisodes)
        {
            this.player = new Player();
            this.SetEneyList();

            var q = new QLearning(
                this.GetSituationSize(),
                this.GetActionSize(),
                0
            );
            int situationNo;
            int nextSituationNo;
            int e = 0;
            while (e < nEpisodes)
            {
                //this.enemy = this.SelectRandomEnemy();
                this.enemy = this.enemyList[e % 3];
                player.RestoreHp();
                player.RestoreMp();
                enemy.RestoreHp();
                enemy.RestoreMp();
                do
                {
                    situationNo = this.GetSituationNo();
                    nextSituationNo = -1;
                    var unselectable = this.GetUnselectableActions(situationNo);
                    int actionNo = q.SelectActionByEGreedy(0.05, situationNo, unselectable);

                    double reward = this.CalcActionResult(situationNo, ref nextSituationNo, actionNo);
                    var nextUnselectable = this.GetUnselectableActions(nextSituationNo);
                    q.updateQ(situationNo, nextSituationNo, actionNo, reward, unselectable, nextUnselectable);

                } while (nextSituationNo != ENEMY_DEATH && nextSituationNo != PLAYER_DEATH);

				if (e % 1000 == 0)
				{
					Console.WriteLine();
                    Console.Write(String.Format("Progress {0:f4}%", (double) 100 * e / nEpisodes));
					Console.WriteLine();
					this.PrintQValuesWithParams(q);
				}
				e++;
            }

            Console.WriteLine();
            this.PrintQValuesWithParams(q);

            return 0;
        }

        public void PrintQValuesWithParams(QLearning q)
        {
            var rowCount = q.qValues.GetLength(0);
            var colCount = q.qValues.GetLength(1);

            int charaId = -1;
            int hpSituation = -1;
            int mpSituation = -1;

            for (int row = 0; row < rowCount; row++)
            {
                if (row >= PLAYER_DEATH)
                    continue;
                this.GetParamsBySituationIndex(row, ref charaId, ref hpSituation, ref mpSituation);
                Console.Write(String.Format("{0}t{1}t{2}", this.enemyList[charaId - 1], hpSituation, mpSituation));
                Console.WriteLine();
                for (int col = 0; col < colCount; col++)
                    Console.Write(String.Format("{0}t", q.qValues[row, col]));
                Console.WriteLine();
            }
        }

        public override int GetSituationSize()
        {
            int enemyKindCount = 3;
            int hpSituationCount = 3;
            int mpSituationCount = 3;
            //自分死亡(PLAYER_DEATH)と相手死亡(ENEMY_DEATH)状態も足す
            return enemyKindCount * hpSituationCount * mpSituationCount + 2;
        }

        public int GetSituationNo()
        {
            return 9 * (this.enemy.GetId() - 1) + 3 * this.GetHpSituation() + this.GetMpSituation();
        }

        private int GetSituationIndexByParams(int charactterId, int hpSituation, int mpSituation)
        {
            return 9 * (charactterId - 1) + 3 * hpSituation + mpSituation;
        }

        private void GetParamsBySituationIndex(int situationId, ref int characterId, ref int hpSituation, ref int mpSituation)
        {
            characterId = situationId / 9;
            hpSituation = (situationId - 9 * characterId) / 3;
            mpSituation = situationId - 9 * characterId - 3 * hpSituation;
            ++characterId;
        }

        public override int GetActionSize()
        {
            return this.player.GetActions().Count;
        }

        public override List<int> GetUnselectableActions(int situation)
        {
            if (situation == PLAYER_DEATH)
				return new List<int>();
            if (situation == ENEMY_DEATH)
				return new List<int>();
            List<int> unselectableActions = new List<int>();

            foreach (KeyValuePair<int, Action> a in this.player.GetActions())
            {
                if (a.Value.mpCost > this.player.GetMp())
                    unselectableActions.Add(a.Value.id);
            }

            return unselectableActions;
        }

        public override double CalcActionResult(int situationNo, ref int nextSituation, int actionNo)
        {

            if (situationNo == PLAYER_DEATH)
            {
                nextSituation = PLAYER_DEATH;
                return PLAYER_DEATH_REWARD;
            }
            if (situationNo == ENEMY_DEATH)
            {
                nextSituation = ENEMY_DEATH;
                return ENEMY_DEATH_REWARD;
            }

            int playerHp = this.player.GetHp();
            int playerMp = this.player.GetMp();
			int playerMpOrg = playerMp;
            int enemyHp = this.enemy.GetHp();
            int enemyMp = this.enemy.GetMp();

            Action playerAction = this.player.GetActions()[actionNo];
            string playerAttackAttribute = playerAction.attribute;
            enemyHp -= (int)(playerAction.hpDamage * this.enemy.GetWeekness()[playerAttackAttribute]);
            playerHp += this.player.GetActions()[actionNo].hpHeal;
            if (playerHp > this.player.GetMaxHp())
                playerHp = this.player.GetMaxHp();
            playerMp -= this.player.GetActions()[actionNo].mpCost;
            if (enemyHp <= 0)
            {
                nextSituation = ENEMY_DEATH;
                return ENEMY_DEATH_REWARD;
            }
            Action enemyAction = this.SelectRandomAction(enemy);
            string enemyAttackAttribute = enemyAction.attribute;
            playerHp -= (int)(enemyAction.hpDamage * this.player.GetWeekness()[enemyAttackAttribute]);
            enemyHp += enemyAction.hpHeal;
            if (enemyHp > this.enemy.GetMaxHp())
                enemyHp = this.enemy.GetMaxHp();
            enemyMp -= enemyAction.mpCost;
            if (playerHp <= 0)
            {
                nextSituation = PLAYER_DEATH;
                return PLAYER_DEATH_REWARD;
            }
            this.player.SetHp(playerHp);
            this.player.SetMp(playerMp);
            this.enemy.SetHp(enemyHp);
            this.enemy.SetMp(enemyMp);


            nextSituation = this.GetSituationNo();

			return (playerMp - playerMpOrg) * MP_CONSUMPTION_REWARD_RATE;
        }

        public void SetEneyList()
        {
            this.enemyList = new List<Character>();
            this.enemyList.Add(new Goblin());
            this.enemyList.Add(new Witch());
            this.enemyList.Add(new Grizzly());
        }

        private Character SelectRandomEnemy()
        {
            Random rnd = new Random();
            int ri = rnd.Next(this.enemyList.Count);
            return this.enemyList[ri];
        }

        private Action SelectRandomAction(Character character)
        {
            Random rnd = new Random();
            List<int> KeyList = new List<int>(character.GetActions().Keys);
            int ri = rnd.Next(KeyList.Count);
            return character.GetActions()[KeyList[ri]];
        }

        private int GetHpSituation()
        {
            if (this.player.GetMaxHp() / 2 < this.player.GetHp())
            {
                return 0;
            }
            else if (this.player.GetMaxHp() / 4 < this.player.GetHp())
            {
                return 1;
            }
            else
            {
                return 2;
            }
        }

        private int GetMpSituation()
        {
            if (this.player.GetMaxMp() / 2 < this.player.GetMp())
            {
                return 0;
            }
            else if (10 <= this.player.GetMp())
            //else if (this.player.GetMaxMp() / 4 < this.player.GetMp())
            {
                return 1;
            }
            else
            {
                return 2;
            }
        }
    }


    public abstract class Character
    {
        private string name;
        private int id;
        private int maxHp;
        private int maxMp;
        private int hp;
        private int mp;
        private Dictionary<int, Action> actions;
        private Dictionary<string, double> weeknesses;

        public Character()
        {
            this.Init();
        }

        private void Init()
        {
            this.name = this.SetName();
            this.id = this.SetId();
            this.maxHp = this.SetMaxHp();
            this.maxMp = this.SetMaxMp();
            this.hp = this.maxHp;
            this.mp = this.maxMp;
            this.actions = this.SetActions();
            this.weeknesses = this.SetWeekness();
        }

        private string SetName()
        {
            return this.GetType().Name;
        }

        public abstract int SetId();
        public abstract int SetMaxHp();
        public abstract int SetMaxMp();

        public void SetHp(int hp)
        {
            this.hp = hp;
        }

        public void SetMp(int mp)
        {
            this.mp = mp;
        }

        public abstract Dictionary<int, Action> SetActions();
        public abstract Dictionary<string, double> SetWeekness();

        public string GetName()
        {
            return this.name;
        }

        public int GetId()
        {
            return this.id;
        }

        public int GetHp()
        {
            return this.hp;
        }

        public int GetMp()
        {
            return this.mp;
        }


        public int GetMaxHp()
        {
            return this.maxHp;
        }

        public int GetMaxMp()
        {
            return this.maxMp;
        }

        public Dictionary<int, Action> GetActions()
        {
            return this.actions;
        }

        public Dictionary<string, double> GetWeekness()
        {
            return this.weeknesses;
        }

        public void RestoreHp()
        {
            this.hp = this.maxHp;
        }

        public void RestoreMp()
        {
            this.mp = this.maxMp;
        }

    }

    public class Player : Character
    {
        public Player() : base()
        {

        }

        public override int SetId()
        {
            return 0;
        }

        public override int SetMaxHp()
        {
            return 100;
        }

        public override int SetMaxMp()
        {
            return 100;
        }

        public override Dictionary<int, Action> SetActions()
        {
            var actions = new Dictionary<int, Action>();
            actions.Add(ActionTable.normalAttack.id, ActionTable.normalAttack);
            actions.Add(ActionTable.magicAttack.id, ActionTable.magicAttack);
            actions.Add(ActionTable.heal.id, ActionTable.heal);
            return actions;
        }

        public override Dictionary<string, double> SetWeekness()
        {
            var weeknesses = new Dictionary<string, double>();
            weeknesses.Add("physical", 1.0);
            weeknesses.Add("magic", 1.0);
            return weeknesses;
        }
    }

    public class Goblin : Character
    {
        public Goblin() : base()
        {

        }

        public override int SetId()
        {
            return 1;
        }

        public override int SetMaxHp()
        {
            return 140;
        }

        public override int SetMaxMp()
        {
            return 0;
        }

        public override Dictionary<int, Action> SetActions()
        {
            var actions = new Dictionary<int, Action>();
            actions.Add(ActionTable.normalAttack.id, ActionTable.normalAttack);
            return actions;
        }

        public override Dictionary<string, double> SetWeekness()
        {
            var weeknesses = new Dictionary<string, double>();
            weeknesses.Add("physical", 1.0);
            weeknesses.Add("magic", 1.0);
            return weeknesses;
        }
    }

    public class Witch : Character
    {
        public Witch() : base()
        {

        }

        public override int SetId()
        {
            return 2;
        }

        public override int SetMaxHp()
        {
            return 100;
        }

        public override int SetMaxMp()
        {
            return 200;
        }

        public override Dictionary<int, Action> SetActions()
        {
            var actions = new Dictionary<int, Action>();
            actions.Add(ActionTable.magicAttack.id, ActionTable.magicAttack);
            actions.Add(ActionTable.heal.id, ActionTable.heal);
            return actions;
        }

        public override Dictionary<string, double> SetWeekness()
        {
            var weeknesses = new Dictionary<string, double>();
            weeknesses.Add("physical", 2.0);
            weeknesses.Add("magic", 0.5);
            return weeknesses;
        }
    }

    public class Grizzly : Character
    {
        public Grizzly() : base()
        {

        }

        public override int SetId()
        {
            return 3;
        }

        public override int SetMaxHp()
        {
            return 240;
        }

        public override int SetMaxMp()
        {
            return 0;
        }

        public override Dictionary<int, Action> SetActions()
        {
            var actions = new Dictionary<int, Action>();
            actions.Add(ActionTable.normalAttack.id, ActionTable.normalAttack);
            actions.Add(ActionTable.strongAttack.id, ActionTable.strongAttack);
            return actions;
        }

        public override Dictionary<string, double> SetWeekness()
        {
            var weeknesses = new Dictionary<string, double>();
            weeknesses.Add("physical", 1.0);
            weeknesses.Add("magic", 2.0);
            return weeknesses;
        }
    }

    public class Action
    {
        public int id;
        public int hpDamage;
        public int hpHeal;
        public int mpCost;
        public string attribute;

        public Action(int id, int hpDamage, int hpHeal, int mpCost, string attribute, ref int count)
        {
            this.id = id;
            this.hpDamage = hpDamage;
            this.hpHeal = hpHeal;
            this.mpCost = mpCost;
            this.attribute = attribute;
            ++count;
        }
    }

    static class ActionTable
    {
        public static int count;
        public static Action normalAttack;
        public static Action magicAttack;
        public static Action heal;
        public static Action strongAttack;

        static ActionTable()
        {
            count = 0;
            normalAttack = new Action(count, 10, 0, 0, "physical", ref count);
            magicAttack = new Action(count, 15, 0, 10, "magic", ref count);
            heal = new Action(count, 0, 40, 10, "magic", ref count);
            strongAttack = new Action(count, 30, 0, 0, "physical", ref count);
        }
    }
}

データの打ち込みを回避するために手書き文字をスキャンして認識する

今回の目的

こんにちは。

Link-Uの町屋敷です。

今回は下のような手書き文字をパソコンにスキャンして認識するプログラムを作ろうと思います。

テストではどこに文字があるかの情報は全く与えず、文字の場所がどこにあるのかから機械に算出させます。(訓練データもいちいち座標与えるのがめんどくさかったのでそうなりましたが…)

普通のアルファベットと微妙に違うのがわかるでしょうか。

実はこれらは全部cに線を書き足してabcdeにしています。

何のためにこんなものの認識をしたいかというと、外部に出さないようなデータのアノテーション(ラベル付け)をするときに、項目とそれに対応するcが書かれた紙を用意して、項目に該当する部分だけcをaとかeなどに変更してあとAIに任せれば作業効率が上がるかなって。

自分でスキャンされた手書き文字認識を実装してみる

さっそく始めましょう。

画像中の物体を認識するための手法としてR-CNNを参考にします。

画像からの物体認識では大きく二つのパートがあります。画像の中から文字が書かれている領域の候補を選択するパートとその領域内に書かれた文字が何なのかを認識するパートです。

Kerasで一度にR-CNNをできるものもある(より新しいMask-RCNNですら存在する)ようですが、地道に領域をから候補を選択してCNNをするという方法でやります。

Selective Search

画像の中から文字が書かれている領域の候補を選択する方法としてR-CNNの元論文ではSelective searchが使われています。Selective searchにはpythonのパッケージが公開されているのでサンプルを改造して試してみましょう。入力画像は屋上から撮った写真を400*300にリサイズしたものです。

import selectivesearch
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

def search(img):
    #img = img[0:300,200:800]
    # perform selective search
    img_lbl, regions = selectivesearch.selective_search(
    #    img, scale=5, sigma=0.25, min_size=20)
        img, scale=5, sigma=5, min_size=30)
    candidates = set()
    for r in regions:
        #excluding same rectangle (with different segments)
        if r['rect'] in candidates:
            continue
        # excluding regions smaller than 2000 pixels
        if r['size'] < 30:
            continue
        # distorted rects
        x, y, w, h = r['rect']
        if h == 0 or w == 0:
            continue
        if w / h > 8 or h / w > 8:
            continue
        candidates.add(r['rect'])

    #draw rectangles on the original image
    fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(6, 6))
    ax.imshow(img)
    for x, y, w, h in candidates:
        rect = mpatches.Rectangle(
            (x, y), w, h, fill=False, edgecolor='red', linewidth=1)
        ax.add_patch(rect)
  
    plt.show()
    return candidates

def test_selective_search():
    im = imread("tower.png")
    search(im[:,:,0:3]) #アルファチャンネルを消す

if __name__ == "__main__":
    test_selective_search()
    learn()
    test()

こんな感じで領域の候補が出力されれば成功です。同じ物体に複数の領域が重なっていますがこれが正しい出力結果です。

Selective searchのパッケージが正しく動作していることが分かったので、さっそく訓練データを入れていきましょう。今回の訓練データは、フォントの大きさが違うcが書かれた紙を4枚印刷し、線を付け加えて作ったa,b,d,eの紙をスキャナーで取り込んで縮小したものです。(縮小したものを使わないとSelective Searchの処理が重い)

実際のものがこれです(aのデータ)、折れとか汚れ(スキャン時にできた)とか雑音はいりまくりです。フォントサイズが変わってもAIが対応できるように微妙にフォントの大きさを変えてますが意味があるかは対称を取ってないのでわかりません。

さてこのままさっきのSelective searchに入れてみましょう。

どえらい結果になります。雑音のせいなので取り除きましょう。

今回は背景が白で黒い文字のみの検出を目指しているのでべたに2値化。

    gray = rgb2gray(im)       # グレイスケールに変換
    whiteRegion = gray > 0.8  # 白部分の雑音除去
    gray[whiteRegion] = 1
    whiteRegion = gray <= 0.8  # 黒部分の雑音除去
    gray[whiteRegion] = 0
    im = gray2rgb(gray)

さてこのままさっきのSelective searchに入れてみましょう。<img class=”alignnone wp-image-337 size-full” src=”https://tech.link-u.co.jp/wp-content/
また、Selective searchの設定も変えましょう。

1文字の認識には長方形の領域はいらないので消去、

また上のほうの小さい文字が認識できてないのでomegaを変更して感度も上げ、画像全体が領域として選択されているのでこれも弾きます。

img_lbl, regions = selectivesearch.selective_search(
img, scale=5, sigma=0.25, min_size=20)
#img, scale=5, sigma=8, min_size=30)
candidates = set()
for r in regions:
#excluding same rectangle (with different segments)
if r['rect'] in candidates:
continue
#excluding regions smaller than 2000 pixels
if r['size'] < 20 or r['size'] > 3000:
continue
# distorted rects
x, y, w, h = r['rect']
if h == 0 or w == 0:
continue
if w / h > 2 or h / w > 2:
continue
candidates.add(r['rect'])

まともになりました。

しかしよく見て見ると一つの文字に複数の領域が重なっています。

もとのR-CNNの論文ではgreedy non-maximum suppressionなどを駆使して除去していますが。今回は汚れなどを除いて背景がなく、文字が重なることもなく、また各文字の大きさも一定なので、単純にある領域に近い領域か存在した場合、領域の中心座標だけ残して領域は全部消し、その点から上下左右16pixを新しい領域にしました。

def get_candidate(img_path):
    def is_checked(checked, test_point,r):
        for c in checked:
            if np.linalg.norm(c - test_point) < r:
                return True
        return False

    im = imread(img_path)
    gray = rgb2gray(im)       # グレイスケールに変換
    whiteRegion = gray > 0.8  # 白部分の雑音除去
    gray[whiteRegion] = 1
    whiteRegion = gray <= 0.8  # 黒部分の雑音除去
    gray[whiteRegion] = 0
    # = median(gray, disk(1))
    #gray = denoise_wavelet(gray) #ごましお雑音除去
    im = gray2rgb(gray)
    cand = search(im)
    print(np.shape(im))
    g_img = np.asarray(gray)
    checked_points = []
    images = []
    ren = 16
    for c in cand:
        xc = int(c[0] + c[2]/2)
        yc = int(c[1] + c[3]/2)
        if len(checked_points) is not 0:
            if is_checked(checked_points, np.array([xc,yc]), 10):
                continue
        image = resize(g_img[yc-ren:yc+ren,xc-ren:xc+ren], (28,28))
        images.append(image)
        #X = np.array([image])

        #plt.imshow(X[0], cmap='gray')
        checked_points.append(np.array([xc,yc]))
        #plt.show()
    return images, checked_points

これをa,b,c,d,eの5枚に行います。これで教師データとして画像データを取り出すことができます。

しかし、ラベルデータがないので

CNN

作ったデータセットでCNNの識別器を作ります。

スクリプトは以下の通り。

def learn():
    import glob
    img_pathes = glob.glob("..\\..\\Documents\\dataset\\abcde\\learn\\scan\\*.*")

    x = np.array([])
    y = np.array([])
    for ip in img_pathes:
        print(ip)
        if "a_" in ip:
            i = 0
        elif "b_" in ip:
            i = 1
        elif "c_" in ip:
            i = 2
        elif "d_" in ip:
            i = 3
        elif "e_" in ip:
            i = 4
        else:
            raise("Unexpected File is Loaded")


        imgs, _ =  get_candidate(ip)
        nimgs = np.array(imgs)
        x = np.append(x, nimgs)
        print(len(x))
        y = np.append(y, np.array(len(imgs) * [i]))
        print(len(y))
    x = np.reshape(x, (-1,28,28))
    seed = 123
    (X_train, X_valid, Y_train, Y_valid) = train_test_split(x, y, test_size=0.15, random_state=seed)
    learn_cnn(X_train, X_valid, Y_train, Y_valid)
    py = load_predict(x)
    for a,b,c in zip(x,y,py):
        if not b == np.argmax(c):
            print((b,c))
            plt.imshow(a, cmap='gray')
            plt.show()

def learn_cnn(X_train, X_test, Y_train, Y_test):
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    X_test  = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

    #X_train = X_train / 255
    #X_test  = X_test / 255
    # one hot encode outputs
    Y_train = np_utils.to_categorical(Y_train)
    Y_test  = np_utils.to_categorical(Y_test)

    num_classes = Y_test.shape[1]


    # create model
    model = Sequential()
    model.add(Conv2D(64, (5, 5), input_shape=(28, 28, 1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(4, 4)))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(5, activation='softmax'))

    # Compile model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=500, batch_size=200, verbose=2)

    # Final evaluation of the model
    scores = model.evaluate(X_test,Y_test, verbose=0)
    print("CNN Error: %.2f%%" % (100-scores[1]*100))

    model.save('weights.model')

Kerasのサンプルに畳み込み層とプーリング層を追加したものです。多分これが最善ではないです。

パラメータを変えるなど各種改造をするときは、testデータを使わずtrainデータを分割して評価して最善のものを目指します。

このような結果が出力されます。テストデータのうち2パーセント間違えたことがわかります。

また、このスクリプトは、学習後に間違えたデータを表示します。

この2つのデータを間違えたようです。(実際は左がa右がd)

試行錯誤繰り返し、満足したところでtestデータを使って認識結果を出力してみましょう。

教師データを分割していましたが、もったいないのですべてを教師データとしてもう一度CNNを学習させます。

learn_cnn(x, X_valid, y, Y_valid)

テストデータは最初に貼ったデータと

このスパースなデータです。

いつものGPUのideapadS720のスピードの差はこちら

GPUのほうが6倍くらい速いです。

一回大量データかつモデルが複雑な場合もやらねばなあ。

で、文字の識別結果は以下の通り。

画像の中の?はCNNの出力を見てAIに自信がなさそうな時に付けています。自信度がなぜわかるか具体的に書くと今回CNNの出力は5クラスの識別問題なので5次元のベクトルで、それぞれが0から1の値を持っています。5つの次元は[a,b,c,d,e]を表します。

例えば、わかりやすいaの画像が入力された場合5次元のベクトルは[1, 0 , 0, 0, 0]に近い値を出力します。少しdとまぎらわしいaが入力された場合[0.8, 0, 0, 0.2, 0]のような値を出します。このベクトルの数値を見ることによってAIの自信度がわかるということです。

1つめは4ミスでしたが2つ目はノーミスでした。もう少し教師データを増やせば2つぐらいは減らせると思います。丁寧に書けば使えるかなといった感じです。

まとめ

今回はスキャンした手書き文字を認識するプログラムを作りました。

字をきれいに書けば使える程度のものは作れたと思います。

プログラム全文

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras import backend as K
from keras.utils import np_utils
from sklearn.model_selection import train_test_split
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches
from decimal import *
import skimage
from skimage.filters.rank import median
from skimage.morphology import disk
from skimage.transform import resize
from skimage.restoration import denoise_wavelet
from skimage.io import imread
from skimage.color import rgb2gray, gray2rgb
from PIL import Image, ImageDraw, ImageFont
import selectivesearch
import time

def load_predict(X):
    X  = X.reshape(X.shape[0], 28, 28, 1).astype('float32')
    #X = X / 255
    # create model
    model = Sequential()
    model.add(Conv2D(64, (5, 5), input_shape=(28, 28, 1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(4, 4)))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(5, activation='softmax'))

    # Compile model
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.load_weights('weights.model')
    return model.predict(X)

def search(img):
    #img = img[0:300,200:800]
    # perform selective search
    img_lbl, regions = selectivesearch.selective_search(
        img, scale=5, sigma=0.25, min_size=20)
        #img, scale=5, sigma=8, min_size=30)
    candidates = set()
    for r in regions:
        #excluding same rectangle (with different segments)
        if r['rect'] in candidates:
            continue
        # excluding regions smaller than 2000 pixels
        if r['size'] < 20 or r['size'] > 3000:
            continue
        # distorted rects
        x, y, w, h = r['rect']
        if h == 0 or w == 0:
            continue
        if w / h > 2 or h / w > 2:
            continue
        candidates.add(r['rect'])

    #draw rectangles on the original image
    #fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(6, 6))
    #ax.imshow(img)
    #===========================================================================
    # for x, y, w, h in candidates:
    #     rect = mpatches.Rectangle(
    #         (x, y), w, h, fill=False, edgecolor='red', linewidth=1)
    #     ax.add_patch(rect)
    #===========================================================================

    #plt.show()
    return candidates

def test_selective_search():
    im = imread("a_scan.png")
    search(im[:,:,0:3])

def learn():
    import glob
    #img_pathes = glob.glob("..\\..\\Documents\\dataset\\abcde\\learn\\*.*")
    #img_pathes += glob.glob("..\\..\\Documents\\dataset\\abcde\\learn\\scan\\*.*")
    img_pathes = glob.glob("..\\..\\Documents\\dataset\\abcde\\learn\\scan\\*.*")

    x = np.array([])
    y = np.array([])
    for ip in img_pathes:
        print(ip)
        if "a_" in ip:
            i = 0
        elif "b_" in ip:
            i = 1
        elif "c_" in ip:
            i = 2
        elif "d_" in ip:
            i = 3
        elif "e_" in ip:
            i = 4
        else:
            raise("Unexpected File is Loaded")


        imgs, _ =  get_candidate(ip)
        nimgs = np.array(imgs)
        x = np.append(x, nimgs)
        print(len(x))
        y = np.append(y, np.array(len(imgs) * [i]))
        print(len(y))
    x = np.reshape(x, (-1,28,28))
    seed = 123
    (X_train, X_valid, Y_train, Y_valid) = train_test_split(x, y, test_size=0.15, random_state=seed)
    #learn_cnn(x, X_valid, y, Y_valid)
    learn_cnn(x, X_valid, y, Y_valid)
    py = load_predict(x)
    for a,b,c in zip(x,y,py):
        if not b == np.argmax(c):
            print((b,c))
            plt.imshow(a, cmap='gray')
            plt.show()

def learn_cnn(X_train, X_test, Y_train, Y_test):
    t = time.time()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    X_test  = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

    #X_train = X_train / 255
    #X_test  = X_test / 255
    # one hot encode outputs
    Y_train = np_utils.to_categorical(Y_train)
    Y_test  = np_utils.to_categorical(Y_test)

    num_classes = Y_test.shape[1]


    # create model
    model = Sequential()
    model.add(Conv2D(64, (5, 5), input_shape=(28, 28, 1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(4, 4)))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(5, activation='softmax'))

    # Compile model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=500, batch_size=200, verbose=2)

    # Final evaluation of the model
    scores = model.evaluate(X_test,Y_test, verbose=0)
    print("CNN Error: %.2f%%" % (100-scores[1]*100))

    model.save('weights.model')
    print(time.time()-t)

def get_candidate(img_path):
    def is_checked(checked, test_point,r):
        for c in checked:
            if np.linalg.norm(c - test_point) < r:
                return True
        return False

    im = imread(img_path)
    gray = rgb2gray(im)       # グレイスケールに変換
    whiteRegion = gray > 0.8  # 白部分の雑音除去
    gray[whiteRegion] = 1
    whiteRegion = gray <= 0.8  # 黒部分の雑音除去
    gray[whiteRegion] = 0
    # = median(gray, disk(1))
    #gray = denoise_wavelet(gray) #ごましお雑音除去
    im = gray2rgb(gray)
    cand = search(im)
    print(np.shape(im))
    g_img = np.asarray(gray)
    checked_points = []
    images = []
    ren = 16
    for c in cand:
        xc = int(c[0] + c[2]/2)
        yc = int(c[1] + c[3]/2)
        if len(checked_points) is not 0:
            if is_checked(checked_points, np.array([xc,yc]), 10):
                continue
        image = resize(g_img[yc-ren:yc+ren,xc-ren:xc+ren], (28,28))
        images.append(image)
        #X = np.array([image])

        #plt.imshow(X[0], cmap='gray')
        checked_points.append(np.array([xc,yc]))
        #plt.show()
    return images, checked_points

def test():
    import glob
    img_pathes = glob.glob("..\\..\\Documents\\dataset\\abcde\\test\\*")

    label_to_char = ["a", "b", "c", "d", "e"]
    i = -1
    for ip in img_pathes:
        x = np.array([])
        y = np.array([])
        print(ip)
        imgs, points =  get_candidate(ip)
        nimgs = np.array(imgs)
        x = np.append(x, nimgs)
        print(len(x))
        y = np.append(y, np.array(len(imgs) * [i]))
        print(len(y))

        x = np.reshape(x, (-1,28,28))
        seed = 123
        pre_y = load_predict(x)
        im = Image.open(ip)
        draw = ImageDraw.Draw(im)
        font = ImageFont.truetype("C:\\Windows\\Fonts\\meiryob.ttc", 16)
        for py, p in zip(pre_y, points):
            pst = np.argsort(py)
            max_val = np.max(py)
            am1, am2 = pst[-1], pst[-2]
            if   max_val > 0.95:
                draw.text((p[0]+16, p[1]), label_to_char[am1], fill=(255, 0, 0), font=font)
            elif max_val > 0.8:
                draw.text((p[0]+16, p[1]), label_to_char[am1] +"," + label_to_char[am2] + "?", fill=(255, 0, 0), font=font)
            elif max_val > 0.4:
                draw.text((p[0]+16, p[1]), label_to_char[am1] +"," + label_to_char[am2] + "??", fill=(255, 0, 0), font=font)
            else:
                draw.text((p[0]+16, p[1]), "?", fill=(255, 0, 0), font=font)
        im.save("hand_only_sparse.png")
        im.show()

if __name__ == "__main__":
    #test_selective_search()
    learn()
    test()

機械学習の入門編 とりあえずライブラリを使ってデータを分類してみる

今回は前回設定した環境を使って、何らかのデータを分類する方法を書きます。

とりあえず機械学習に触ってみたい人向けにPythonプログラムを書いて結果が出るまでを解説します。アルゴリズムの詳細やハイパーパラメータの意味は専門書を読んだほうがいいので書きません。(間違っていたら怖いし)プログラムの全文は最後に書いてます。

今回扱う問題

今回は教師あり学習の中の分類問題を扱います。教師あり学習は、すでに知っている入力Xと出力Yの情報からいい感じのXとYの関係を推定して、新しい入力X’が来た時にその関係性を使って正しいY’を求めることができるプログラム[学習器]を生成することが目標になります。画像に何が写ってるかを当てるみたいな問題を分類問題、立地とか家の素材、周辺地域の平均年収とかから家の値段を推定するみたいな問題を回帰問題といいます。今回扱うのは分類問題です。でも画像の分類みたいな楽しそうなやつではなく地味なデータでやります。

-SVMを使って分類してみる-

データ例はこれです。

この点を生成するスクリプトはこちら。

    SIZE = 1000
d = [[0,0,0,0,0,0,0,0,1,0],
[0,1,2,2,2,2,2,2,1,0],
[0,1,2,0,0,0,0,0,1,0],
[0,1,2,0,2,2,2,0,1,0],
[0,1,2,0,0,0,2,0,1,0],
[0,1,2,2,2,2,2,0,1,0],
[0,1,1,1,1,1,1,0,1,0],
[0,1,0,0,0,0,0,0,1,0],
[0,1,0,0,0,0,0,0,1,0],
[0,1,1,1,1,1,1,1,1,0]]
x = []
y = []
ud = np.random.rand #alias
for _ in range(SIZE):
a = ud()*10
b = ud()*10
x.append([a, b])
y.append(d[int(a)][int(b)])
x = np.array(x)
x = np.reshape(x,(-1,2))
y = np.array(y)

1000点生成して訓練用、テスト用にそれぞれ500点ずつ分割します。

ここで、numpyのappendはpythonネイティブの[].appendと比べて滅茶苦茶遅いので使わないようにします。

本来はもう一つ本当のテスト用にデータを分割しないといけないんですが、ややこしいので今回はしません。

今回の機械学習の目標は1つ1つの点が3つのうちどのクラスから生成されたかを推定することです。

Xとyの具体的な数値はこのようになっています。例えばXとyの最初の要素は[12.13, 7.37]の座標に点があってそれは1の領域にあることを表しています。

さっそくですが問題を解くためのアルゴリズムを適当に決めてみましょう。それらはディープラーニングを含めてたくさんの種類があります。一般的なものを使用せる場合はライブラリが存在しているので自分でコードを書く必要はありません。

今回はSVMとニューラルネット(階層が深くなるとディープラーニング)をそれぞれ実装してみます。

各アルゴリズムが何をしてるのかは、ここではとても説明できないので必要に応じてパターン学習と機械学習(PRML)とかの参考書で。

まずscikit-learnライブラリを用いてSVMを実装します。scikit-learnにはSVCとSVRの関数がありますが、今回は分類問題なのでSVCを使用します。回帰問題の場合はSVRです。

from sklearn.svm import SVC
clf = SVC(C=30, kernel="rbf", gamma=0.01,class_weight = "balanced")
clf.fit(X_train, y_train)
y_predict = clf.predict(X_test)

このy_predictに予測したラベルが入っています。

SVCにはハイパーパラメータとしてC,カーネル,ガンマが存在します。これらの変数は自動的に決まるものではなく自分で与えてあげるもので、SVMに限らずほぼすべてのアルゴリズムに存在します。

これらの最適値はデータごとに変わるので、通常はグリッドサーチという作業を行い最適なものを選択します。C, ガンマ, カーネルに様々な値を入力したときの一例は以下の通りです。

linearカーネルは境界線が直線なのに対してrbfカーネルは境界線が曲線なのが分かります。またrbfカーネルのガンマが大きいほうがより境界線が複雑になっています。Cは各領域内にノイズが含まれないので300のほうが結果がよくなるようです。

-ニューラルネットも使ってみる-

先ほどはSVMを用いてデータを分類してみました。次はニューラルネットを用いて分類してみます。

データセットは全く同じです。

Kerasを使ってニューラルネットもモデルを作ります。今回は入力データが2次元、分類するクラスは3クラスなので間に隠れ層としてノード数が5のモデルは以下のスクリプトで作ることができます。

from keras.layers import Input, Dense, Activation, Dropout
from keras.models import Model, Sequential
from keras.wrappers.scikit_learn import KerasClassifier

def make_model():
model = Sequential()
model.add(Dense(5, input_dim=2, activation='relu'))
model.add(Dense(3, activation='softmax'))
adam = optimizers.Adam(lr = 0.001, decay = 0)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy'],
)
return model

X_train = X_train.reshape((len(X_train), np.prod(X_train.shape[1:])))
y_train = np.reshape(y_train, (np.shape(X_train)[0],1))
clf = KerasClassifier(make_model, batch_size=100)
clf.fit(X_train, y_train, epochs=10000, validation_data=(X_test, y_test),)

このプログラムで入力層のノード数が2,隠れ層のノード数が5,出力層で3クラスの分類を行うモデルが作られます。表にするとこんなモデルです。

早速ですが結果を見てみましょう。

ぜんぜんですね。それでは次は隠れ層をものすごく増やして同じことをやってみましょう。

def make_model():
model = Sequential()
model.add(Dense(200, input_dim=2, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(3, activation='softmax'))
adam = optimizers.Adam(lr = 0.001, decay = 0)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy'],
)
return model

結果はこうなりました。

先ほどよりはよくなりましたが、もっと良くするにはどうすればいいでしょうか。

原因を探るためにAccuracyとLossのepochごとの経過を見てみましょう。

ニューラルネットでは各ノードの適切な重みを計算するために何回も訓練データを使って学習します。

その回数がepoch数で,訓練データとテストデータそれぞれ見ます。

まずは層の少ないほう。

trainのAccuracyがそもそも低いのでネットワークの表現力が足りないことがわかります。

つまりノード数が少なすぎるということです。

次にノードの多いほうを見てみましょう。

trainのほうはaccuracy,lossともに回数を重ねるにつれて改善されてますが、testのほうは悪化しています。これは典型的な過学習の症状です。これを解消するにはノードを減らす、Dropoutやepoch数を減らす、trainデータを増やすことが考えられます。

実際Dropoutを行うコードを追加すると。まだ過学習気味ですが、結果はよくなります。

def make_model():
model = Sequential()
model.add(Dense(200, input_dim=2, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dense(200, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(200, activation='relu'))
model.add(Dense(3, activation='softmax'))
adam = optimizers.Adam(lr = 0.001, decay = 0)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy'],
)
return model

結果はこうなりました。

緑の領域の左側で改善が見られます。

また、入力データ数を50倍にするとより正確な結果が出ます。

隠れ層の数を増やしていくとニューラルネットの表現力は増加しますが、計算にかかる時間は増加します。教師データ数25000,ノード数200の隠れ層の数が1から10の時の所要時間は前回使用したGPUサーバとノートパソコンでそれぞれ計測すると次のようになりました。一応スペックをもう一回。

手元のノートパソコン(ideapad 720S)GPUサーバー
CPUCore i7 8550U 1.8Ghz 4コア8スレッドXeon Silver 4116 × 2 2.1Ghz 24コア48スレッド
メモリ8GB128GB
GPUなしTesla V100 × 2 + 1080Ti × 2
OSUbuntu 16.04 DesktopUbuntu 16.04 Server

層が増えるにしたがって時間が増加していますが、増加率はノートパソコンのほうが多い結果になりました。CNNの時はもっと差があったのでモデルが複雑になるほど差が出るのかも?

今回はこれで終了します。本当はここからどのアルゴリズムを使うかや、適切なハイパーパラメータをチューニング、データの前処理をして正しい方法で評価し、一番使えるものを探さなければなりません。

まとめ

今回は本当に使ってみただけで雰囲気程度でした。

次回からは実際の業務に使えそうなものを作ることに挑戦しようと思います。

プログラム全文

import numpy as np
from matplotlib import pyplot as plt
import joblib

def make_data(SIZE):

d = [[0,0,0,0,0,0,0,0,1,0],
[0,1,2,2,2,2,2,2,1,0],
[0,1,2,0,0,0,0,0,1,0],
[0,1,2,0,2,2,2,0,1,0],
[0,1,2,0,0,0,2,0,1,0],
[0,1,2,2,2,2,2,0,1,0],
[0,1,1,1,1,1,1,0,1,0],
[0,1,0,0,0,0,0,0,1,0],
[0,1,0,0,0,0,0,0,1,0],
[0,1,1,1,1,1,1,1,1,0]]


x = []
y = []
ud = np.random.rand #alias
for _ in range(SIZE):
a = ud()*10
b = ud()*10
x.append([a, b])
y.append(d[int(a)][int(b)])
x = np.array(x)
x = np.reshape(x,(-1,2))
y = np.array(y)
joblib.dump(x,"mx.pkl")
joblib.dump(y,"my.pkl")

def main():
x = joblib.load("mx.pkl")
y = joblib.load("my.pkl")
dx = x
#データの一部を表示
fig = plt.figure()
ax = fig.add_subplot(111)
colors = ["#ff0000", "#00ff00", "#0000ff"]
print("Drawing Images")
if len(dx) > 3000:
px = dx[0:3000]
else:
px = dx
for i ,v in enumerate(px):
ax.scatter(v[0], v[1], c=colors[int(y[i])-1], marker='o', alpha = 0.3)
ax.set_title('Dataset')
print("Finish")
plt.show()
plt.close()

from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

seed = 12345
(X_train, X_test, y_train, y_test) = train_test_split(x, y, test_size=0.50, random_state=seed)
for k in ["rbf", "linear"]:
for c in [30,300]:
for g in [0.01, 1]:
clf = SVC(C=c, kernel=k, gamma=g,class_weight = "balanced")
clf.fit(X_train, y_train)
y_predict = clf.predict(X_test)
ac = accuracy_score(y_test, y_predict)
f1 = f1_score(y_test, y_predict, average="macro")
fig = plt.figure()
ax = fig.add_subplot(111)
if len(X_test) > 3000:
px = X_test[0:3000]
else:
px = X_test
colors = ["#ff8800", "#00ff88", "#8800ff"]
print("Data plotting")
for i ,v in enumerate(px):
ax.scatter(v[0], v[1], c=colors[int(y_predict[i])-1], marker='o', alpha = 0.3)
ax.set_title('')
print("Finish")
print([k,c,g])
plt.show()
plt.close()
np.set_printoptions(suppress=True)
np.set_printoptions(threshold=np.inf, precision=2, floatmode='maxprec')
print(ac)
print(f1)

from keras.layers import Input, Dense, Activation, Dropout
from keras.models import Model, Sequential
from keras.wrappers.scikit_learn import KerasClassifier
from keras import optimizers

layers = 0

def make_model():
model = Sequential()
model.add(Dense(200, input_dim=2, activation='relu'))
for _ in range(layers):
model.add(Dense(200, activation='relu'))
model.add(Dense(3, activation='softmax'))
adam = optimizers.Adam(lr = 0.001, decay = 0)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=adam,
metrics=['accuracy'],
)

#plot_model(model, to_file='model.png', show_shapes=True)
return model

X_train = X_train.reshape((len(X_train), np.prod(X_train.shape[1:])))
y_train = np.reshape(y_train, (np.shape(X_train)[0],1))

for i in range(10):
layers = i
clf = KerasClassifier(make_model, batch_size=100)
history = clf.fit(X_train, y_train, epochs=100, verbose = 0, validation_data=(X_test, y_test))

np.set_printoptions(suppress=True)
np.set_printoptions(threshold=np.inf, precision=2, floatmode='maxprec')
ac = accuracy_score(y_test, y_predict)
f1 = f1_score(y_test, y_predict, average="macro")
print(ac)
print(f1)

# summarize history for accuracy
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

y_predict = clf.predict(X_test)

fig = plt.figure()
ax = fig.add_subplot(111)
if len(X_test) > 3000:
px = X_test[0:3000]
else:
px = X_test
colors = ["#ff8800", "#00ff88", "#8800ff"]
for i ,v in enumerate(px):
ax.scatter(v[0], v[1], c=colors[int(y_predict[i])-1], marker='o', alpha = 0.3)
ax.set_title('Test plot')
print("Finish")
plt.show()
plt.close()


if __name__ =="__main__":
make_data(50000)
main()

[初心者用]Pythonの環境構築とSSHとCUDAを使ってGPUサーバーで機械学習をする方法

こんにちは、はじめまして。Link-Uの町屋敷です。

技術ブログを立ち上げるとのことで、ちょうど機械学習の環境を整える必要があったのでその詳細を書きます。環境構築って初投稿っぽいですし。

次にやる人のためのときのメモも兼ねてるからできるだけ色んなサイトに行かなくていいようにしたら結構長くなった。どこか間違ってたらごめんなさい。

折角なので最後に手元のノートパソコンとGPUサーバーの両方で実行して手書き文字認識のベンチマークを取ってみます。

環境

手元のノートパソコン(ideapad 720S)GPUサーバー
CPUCore i7 8550U 1.8Ghz 4コア8スレッドXeon Silver 4116 × 2 2.1Ghz 24コア48スレッド
メモリ8GB128GB
GPUなしTesla V100 × 2 + 1080Ti × 2
OSUbuntu 16.04 DesktopUbuntu 16.04 Server

Pythonを使う準備

今回はIDEとしてEclipse、環境構築用にAnacondaを使います。PythonのIDEでEclipse使う人は少数派らしいけど。

Eclipseの設定

公式サイトのダウンロードボタンを押す。

ファイルを保存を選択して、ダウンロードしたフォルダに行き右クリックをして「ここに展開する」を選択。
出来上がったeclipse-installerフォルダ内のeclipse-instをダブルクリックする。

今回はjreがないと怒られたから取りに行きます。
Oracleの公式サイトからjreをダウンロード

規約の同意(Accept License Agreement)をチェックしないとダウンロードできないので注意

先ほどと同様にダウンロードしてきたファイルを展開して、生成されたjreから始まるフォルダをjreにリネームしてeclipse-installerの中に置けばeclipse-instが通るようになる。

pythonが選択肢にないのでとりあえずEclipse IDE for java Developpersを選択、フォルダを選択してインストール。インストールが完了したらLaunchボタンを押して起動。
ここでWorkspaceの選択をする。選択したフォルダ内にpythonのプロジェクトを置いていくことになる。

起動したら上のHelpタブからEclipse Marketplaceを選択し、pydevを検索。

installボタンからインストールする。インストールが完了すると、Eclipseを再起動するか聞かれるので再起動
初めはWelcomeタブが出ていると思うのでそれをxを押して消す。右上のOpenPerspectiveボタンを押して表示されるウィンドウにPydevが含まれていれば成功。

Anacondaの設定

仮想環境を使わないとpythonのバージョンが変わった時やバグった時に最悪OSから入れなおしになる事件が発生するかもしれないので、Anacondaを使って仮想環境を使えるようにする。
Anacondaを公式サイトからダウンロード。
よほどのことがない限りpython3.6バージョンでいい。
ターミナル(Ctr+Alt+T)に行って以下を実行

sudo sh '/home/user/ダウンロード/Anaconda3-5.1.0-Linux-x86_64.sh'

*userは人によって違う

途中利用規約やインストールパスを聞かれる。

 Do you wish the installer to prepend the Anaconda3 install location
to PATH in your /home/m/.bashrc

はyes Microsoft VSCodeのインストールはEclipseを使うならいらない。
インストールが終わったら一旦再起動
再起動したらターミナルを開いて仮想環境を作る

conda create -n python3.6-env

-n の後ろに環境名をかく

pythonのバージョンを環境名に入れておくと複数の環境を作った時にバージョンを間違えないから便利

source activate python3.6-env

で環境に入る。するとターミナルのパスの前に環境名がつく

which pythonと打つと現在使用しているpython.exeの場所がわかる。

機械学習によく使うパッケージをインストール

conda install numpy scipy sikit-learn keras pandas matplotlib

その後

python
import numpy, scipy, pandas, sklearn, matplotlib, keras

と打ってエラーが出なければ成功。Ctrl+Dで抜けます。

パッケージをインストールをインストールする方法にはpipとconda (とgit)があるけど基本的にpipとconda両方あるパッケージはcondaを使ったほうが良さそう。
pipはソースからコンパイルするものもあるからハマった時はCやFORTRANなどの闇と戦わなければならんらしい。

EclipseでAnacondaの仮想環境でインストールしたパッケージを使う

先にインストールしたEclipseのpydevで普通にpythonスクリプトを書いても反映されない。

反映させるには、まずeclipseを開き、penPerspectiveボタンを押して表示されるウィンドウにのPydevを選択してopenする。すると右上にpythonマークが出てくるのでこれを選択し、上のFile->New->PydevProjectを選択すると以下のようなウィンドウが出てくるのでプロジェクト名を適当に決めて真ん中付近のplease configure an interpreter before proceedingをクリックしManual Configを選択

その後右のNewからInterpreterNameを適当に決め、Browseを選択し、Anacondaの仮想環境内でwhich pythonを打った時に出てくるpythonを登録する。あとはOKを2回押してApply.Interpreterが変更したものになってることを確認してFinishを押す。

できたか確認しよう。
File→New→Fileから作ったプロジェクトを選択し、ファイルを名前の最後が.pyになるように生成。
短い名前はライブラリにすでに使われているファイル名と被った時めんどくさいことになるので避ける。

GPUサーバーの設定

このままでも機械学習を始められるが、ノートパソコンなどの普通の性能のパソコンだとビッグデータを使おうとすると処理にかなり時間がかかってその間パソコンが使えなくなったり、そもそもメモリが足りなくて動かないことのないよう小さいデータセットでテストしたら処理を別のものにやらせたほうがいい。今回はGPUサーバーがあるのでそれを使う。

サーバーといってもOSを入れれば普通のパソコンと同様に設定できる。OSはUSBからUbuntu16.04をインストールした。SSHをつなげるまではサーバーに画面とキーボードを直接つなげて設定した。繋げずにやる方法もあるらしいが試してない。

ログイン画面でAlt+Ctr+F1でCUIに入れる。OSにUbuntu Serverを選んだなら元々CUI。

SSHでサーバーと接続する

GPUサーバーは起動時に100dB超えるくらいの爆音であまり近づいて作業したくないのでSSHで全部やってしまう。

ここからサーバー側とクライアント側の操作がどっちがどっちかややこしくなるので行の先頭にサーバー側なら[S],クライアント側なら[C]をつける。

まずサーバーのIPをチェック

[S]ifconfig

inetアドレス:192.168.101.48 と書かれているこれがサーバーのIPアドレス。

長く使うならIPは固定したほうがいい。

[C]sudo apt-get install ssh
[C]ssh 192.168.101.48 -l [サーバー名]

SSHをインストールして、サーバーと接続。パスワードを聞かれるので、サーバーのユーザーのログイン時のパスワードを入れる。

こうなったらサーバーとの接続は成功。このターミナルともう一つターミナルを開くことでクライアント側のパソコンだけで処理が全てできる。

CUDAの設定

デフォルトのドライバではパフォーマンスが出ないので、NVIDIAから搭載しているGPU専用のドライバを取ってくる。

基本NVIDIAの公式インストールガイドに沿っていくがこの工程で結構バグが出て戦ったので最終的にできたやつを書いておく。公式の存在を知ったのが結構戦ったあとだったのが長引いた原因かも。

下手すると最悪OS再インストールになるので大事なデータなどは退避させたほうがいい。

sudo apt-get purge cuda*
sudo apt-get purge nvidia*
dpkg -l | grep nvidia
sudo wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
sudo dpkg -i cuda-repo-ubuntu1604_8.0.61-1_amd64.deb
sudo apt-get update
sudo apt-get upgrade
sudo apt-get install cuda
sudo nvidia-xconfig
export CUDA_HOME=/usr/local/cuda-9.1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${CUDA_HOME}/lib64
export PATH=$PATH:{$CUDA_HOME}/bin
sudo reboot
nvidia-smi

4, 5行目は人によって変える必要あり

最後にこんなのが出てきたらドライバのインストールはできている。

cuDNNの設定

これも公式サイトを参考にする。

最初にcuDNNの公式サイトからcuDNNをダウンロード。画面が必要なのでクライアント側でダウンロードしてサーバーに送りつけることにする。Nvidia Developer Programに登録しないとダウンロードできない。

[C] sudo scp '/home/[クライアントユーザー名]/Downloads/cudnn-9.1-linux-x64-v7.1.tgz' [サーバー名]@[サーバーIPアドレス]:/home/[サーバーユーザー名]
[S] sudo tar -xzvf ../../home/[サーバーユーザー名]/cudnn-9.1-linux-x64-v7.1.tgz
[S] sudo cp -P cuda/include/cudnn.h /usr/local/cuda-9.1/include
[S] sudo cp -P cuda/lib64/libcudnn* /usr/local/cuda-9.1/lib64/
[S] sudo chmod a+r /usr/local/cuda-9.1/lib64/libcudnn*

scpでファイルを送りつけてその後は公式サイトをなぞるだけの作業。cudaとかcuDNNのバージョンだけ注意

これでサーバーでプログラムを動かせるが、まだanacondaのライブラリたちが使えない。

‘/home/user/Downloads/Anaconda3-5.1.0-Linux-x86_64.sh’をサーバーに送りつけて先ほどと同じことをすれば使えるようになる。

[C] sudo scp '/home/[クライアントユーザー名]/Downloads/Anaconda3-5.1.0-Linux-x86_64.sh' [サーバー名]@[サーバーIPアドレス]:/home/[サーバーユーザー名
[S] sudo sh ./Anaconda3-5.1.0-Linux-x86_64.sh 以下同様

パッケージが増えたらAnacondaの環境を複製する方法を使うと楽そう。

サーバー側ではGPUを使うのでKeras-GPUとtensorflow-GPUもインストールする。

テスト

ついに環境ができたのでテスト。

使用したのはKerasライブラリのCNNを使った数字識別のサンプルコード mnist_cnn.pyに時間を計るコードだけを追加したもの。

試しに手元のノートパソコンideapadの結果はこうなった。

accuracyが正解率でtimeがかかった秒数、大体46分かかっている。

次にサーバーでやってみよう、サンプルをサーバーに送り、走らせる、

[S] source activate [環境名]
[S] python [ファイルを追加したフォルダのパス]mnist_cnn.py

結果はこれ。

手元のノートパソコン(ideapad 720S)GPUサーバー
Time2753.96s100.28s
Accuracy0.99130.9917

ちょうど100秒で終わった。速度は27.5倍くらい速くなっているだけでなくaccuracyも若干上がっている。1回しかテストしてないのでたまたまかもしれないが、Teslaのほうが倍精度をより正確に計算できるので、それが影響してるのかもしれない。ちなみにTeslaは2基搭載されているが、簡単のためそのうち1つしか使っていない。

まとめ

今回はGPUサーバーでPythonを使った機械学習を行う準備方法の解説でした。

次回からの自分の担当分はいろんな機械学習の手法を手元のデータに使っていって、そのやり方や結果を書いていきたいです。