概要
FaxOCRという手書き数字認識の問題に挑戦した。 mnistで学習させたCNNでテストデータを判別すると正答率70%弱と低かった。 FaxOCRのデータだけで学習させたCNNでは96%程度の正答率だった。 mnistのデータとFaxOCRのデータはどうも違うようだけど、何が違うのかよくわからない。
以下はipython notebookの出力をちょこちょこいじったので、変なところがいくつかある。
はじめに
ツイッターを眺めているとこんなツイートを見つけた。
どうもFaxで送られてきた手書き数字を認識したいらしい。 はじめから電子データでいいんちゃうかとか、お役所も色々大変なんだろうなぁとか思いつつ。 手書き数字認識とかCNNに突っ込んだら終わりっしょぐらいの感じで始めた。先日公開したMNIST形式の手書き文字データですが、より精度の高い前処理を行った版を公開しました。それに合わせて、トレーニングデータも今までの4倍量を公開しました。元画像も公開中。文字認識タスクに興味のある方、お気軽にお試し下さい。https://t.co/wDq0lx59SL
— Takashi Okumura (@tweeting_drtaka) 2016年4月13日
FaxOCR
サイトはこちら
バイナリ形式で画像とラベルのセットが用意されている。mnistと同じ形式らしい。 やり始めた当時は学習データが1711画像だった(たぶん)。 mnistが70000枚とかなのを考えても、CNNに入れるにはだいぶ少ないなぁという印象だったので、mnistで学習させたCNNを使ってFaxOCRのテストデータを分類することにした。
mnistで学習させたCNN
mnistでCNNを学習させるというのはTensorFlowのチュートリアルにあるぐらい鉄板なので、ググるとたくさんヒットする。個人的にchainerの使い方なら、なんとなく習得しているので、こちら のchainerの実装を使わせてもらうことにした。 CNNやらのコードは完全にコピペなのでここには書かないが、エポック数だけ20から50に変更した。
学習結果はtest_accuracy=0.694779とむっちゃ低かった。 以下はエポック毎のaccuracy(青)と誤差関数(緑)。学習は進んでいるのに、accuracyが上がっていない。
%matplotlib inline import pylab as plt import pandas as pd accuracy = pd.read_csv("./accuracy_mnist2fax.txt", sep="\t") loss = pd.read_csv("./loss_mnist2fax.txt", sep="\t") fig = plt.figure(figsize=[10,10]) accuracy["test_accuracy"].plot() plt.ylim([0,1]) loss["train_loss"].plot(secondary_y=True) plt.title("mnist->faxor") plt.show()
mnistのデータをテストデータにすると以下のようになる。 lossは同じぐらいなのに、accuracyは1エポックの時点で0.98を超えている。
accuracy = pd.read_csv("./accuracy_mnist2mnist.txt", sep="\t") loss = pd.read_csv("./loss_mnist2mnist.txt", sep="\t") fig = plt.figure(figsize=[10,10]) accuracy["test_accuracy"].plot() plt.ylim([0,1])#plt.ylim([0,1]) loss["train_loss"].plot(secondary_y=True) plt.title("mnist->mnist") plt.show()
一応、用意された学習データで学習すると以下のようになって、やっぱりデータ足りてない感がある。
accuracy = pd.read_csv("./accuracy_fax2fax.txt", sep="\t") loss = pd.read_csv("./loss_fax2fax.txt", sep="\t") fig = plt.figure(figsize=[10,10]) accuracy["test_accuracy"].plot() plt.ylim([0,1]) loss["train_loss"].plot(secondary_y=True) plt.title("faxor->faxor") plt.show()
データを眺める
手書き数字といえばmnistでおっけーというイメージだったので、ちょっとショックだった。 精度が出ていない原因として、データが質的に異なっているという可能性があるので、色々と可視化してみることにした。
画像データ
何はともあれ学習データの画像を見る。以下は可視化のコード。全部表示すると多すぎるので、適当に10件ずつだけ。
%matplotlib inline from read_data import read,show from sklearn.datasets import fetch_mldata import numpy as np mnist = fetch_mldata('MNIST original', data_home=".") X = mnist.data y = mnist.target X = X.astype(np.float32) y = y.astype(np.int32) X /= X.max() X_train = X y_train = y data = read() test_data = read(dataset="testing") X_test = test_data[0] y_test = test_data[1] X_test = X_test.reshape(len(y_test), -1) X_test = X_test / float(X_test.max()) import random indecies = random.sample(range(len(X_train)), 1000) for i in range(10): show(X_test[i].reshape(28,28)) print "###################################################" print "############## mnist ここから######################" print "###################################################" for i in range(10): show(X_train[indecies[i]].reshape(28,28))
mnist ここから
ぱっと見た感じ、mnistのほうが太い線で書かれたものが多い気がする。とはいえ、細い線のものもあるし、認識してくれてもいいんじゃないかという印象。
t-SNEによる可視化
なんかよくわからんときは、とりあえずt-SNEにぶちこむというのが最近のマイブーム。 このあたりが詳しい。 これとかをみると、PCAよりええんちゃうって思う。 scikit-learnに関数があるので、使うのはとても簡単。mnist全データを突っ込むとメモリが足りないと怒られたので、ランダムに1000点選んで描画した。
data = np.r_[X_train[indecies], X_test] from sklearn.manifold import TSNE model = TSNE(n_components=2) tsned = model.fit_transform(data)
import pylab as plt label = np.r_[["b" for i in X_train[:1000]], ["r" for i in X_test]] plt.figure(figsize=(30,30)) plt.scatter(tsned[:,0], tsned[:,1], c=label, linewidths=0) plt.show()
青(mnist)と赤(FaxOCR)のデータが明らかに分離している。こりゃあかんわって感じ。
データを増やしてCNN
そうこうしているうちに、こちらに精度を抜かれてしまっていた。
"適当に拡大縮小や回転をして画像データの枚数を11倍に(1枚から10枚生成)しました。"
とあって、すごく妥当だと思うし、なんで自分はこんなわけわからんことやってんだろうと思う。 とはいえ、精度で負けてるのはなんか悔しいので48x48のデータを51倍にしてCNNに突っ込んだ。
一応画像を適当に増やすところのコードは以下。
size = tuple(np.array([X_train[0].shape[1], X_train[0].shape[0]])) new_imgs = [] new_labels = [] import random import cv2 from itertools import izip for img, label in izip(X_train, y_train): for i in range(50): rad = (random.random() - 0.5) * 0.5 pos1 = (random.random() - 0.5) * 5 pos2 = (random.random() - 0.5) * 5 mat = np.float32([[np.cos(rad), -1 * np.sin(rad), pos1], [np.sin(rad), np.cos(rad), pos2]]) dst = cv2.warpAffine(img, mat, size, flags=cv2.INTER_LINEAR) new_imgs.append(dst) new_labels.append(label) X_train = np.r_[X_train, new_imgs] y_train = np.r_[y_train, new_labels]
accuracy(青)と誤差関数(緑)は以下。とりあえず96%とかいっているのでまあまあというところ。
accuracy = pd.read_csv("./accuracy_fax2fax_copied_48.txt", sep="\t") loss = pd.read_csv("./loss_fax2fax_copied_48.txt", sep="\t") fig = plt.figure(figsize=[10,10]) accuracy["test_accuracy"].plot() plt.ylim([0,1]) loss["train_loss"].plot(secondary_y=True) plt.title("faxor48->faxor48") plt.show()
ちなみに間違っていた画像は以下。これはしょうがないんじゃないかと思うものが多い。(というか学習データのほうは大丈夫なんだろうか…)
%matplotlib inline import cPickle as pickle import numpy as np import chainer from chainer import cuda import chainer.functions as F xp=np def forward(x_data, y_data): x, t = chainer.Variable(x_data), chainer.Variable(y_data) h = F.max_pooling_2d(F.relu(model.conv1(x)), 2) h = F.max_pooling_2d(F.relu(model.conv2(h)), 2) h = F.dropout(F.relu(model.l1(h)), train=False) y = model.l2(h) return y, t,F.accuracy(y,t) with open("model_cnn_48.pkl", 'rb') as i: model = pickle.load(i) from read_data import read, show test_data = read(dataset="testing", size=48) X_test = test_data[0].astype(xp.float32) y_test = test_data[1].astype(xp.int32) X_test = X_test.reshape(len(y_test), -1) X_test = X_test / float(X_test.max()) X_test = X_test.reshape((len(X_test), 1, 48, 48)) y,t,acc = forward(X_test, y_test) print "#################################" print "accuracy: " + str(acc.data) print "#################################" from itertools import izip for i,(temp_y,temp_t,temp_X_test) in enumerate(izip(y.data,t.data, X_test)): if np.argmax(temp_y)!=temp_t: print "No.%d 正解:%d 出力:%d (%s)"%(i,temp_t, np.argmax(temp_y),temp_y) show(temp_X_test.reshape(48,48))
accuracy: 0.967871487141
No.14 正解:9 出力:8 ([ 3.53644633 -66.54345703 -44.21648026 -30.54708862 -47.66264725 -61.4640274 -36.32198715 -54.34354401 28.37440872 14.29025173])
No.116 正解:9 出力:5 ([-35.68978119 -23.23931885 -76.64511108 -26.64510727 -42.89046097 50.29698944 -16.02383232 -52.14741135 -22.0790844 -29.76072311])
No.127 正解:9 出力:8 ([ -0.36938047 -72.61532593 -38.76506424 -25.25551796 -40.20273972 -45.48267365 -41.19197464 -39.32862854 26.44895935 11.87366581])
No.170 正解:7 出力:1 ([-13.41542912 8.12284565 -28.43426895 -6.35744619 -42.47330093 -31.51161766 -36.15800858 -0.18305674 -23.2899704 -10.80175686])
No.172 正解:9 出力:1 ([-18.31829453 8.41508961 -24.00697708 -18.74674988 -15.38268757 -19.12284851 -21.66376686 -29.5259037 -9.97081757 5.56778383])
No.196 正解:9 出力:4 ([-25.27557564 -20.27404976 -33.58036041 -38.26721573 18.712677 -17.12530899 -26.8935318 -34.70022964 -20.10196686 11.45114803])
No.231 正解:5 出力:8 ([ -3.76606822 -27.21845436 -40.48172379 -41.74909592 -20.90390015 -27.25065613 8.70675373 -39.89279938 28.23112106 -19.23669815])
No.247 正解:2 出力:4 ([-15.3788166 -8.5814476 -14.29022121 -16.61594963 5.68172646 -8.83567333 0.14922404 -22.43881035 -17.46813393 -14.75987816])
mnistデータとFaxOCRデータの違い
データを公開された方の本来の目的からすると、パラメータチューニングとかして性能を上げたほうがいいのかもしれないけど、正直そっちにはあまり興味がない。Kaggleガチ勢の方がこういうの とか出してくれているので、参考にするといいのかもしれない。
個人的に気になったのは今回のデータとmnistの違い。 見た目は同じような手書き文字なのに、t-sneで可視化すると明らかに分離している。 Faxのデータはどうも線の細さを統一したり、回転させたりと前処理を結構しているらしいので、それが影響しているのかなと思った。 なので、生データを可視化してみる。FaxOCRの画像サイズはまちまちだったので、ImageMagickで28x28にリサイズした。アスペクト比は保存していないのでややまずい。
import cv2 as cv import glob mnist = fetch_mldata('MNIST original', data_home=".") X = mnist.data y = mnist.target X = X.astype(np.float32) y = y.astype(np.int32) X /= X.max() X_train = X y_train = y X_test = [] y_test = [] for img_file in glob.glob("./data/mustread/28/*.png"): y_test.append(int(img_file[16])) X_test.append(255 - cv.imread(img_file, flags=0)) data_row = np.r_[X_train[indecies], np.array(X_test).reshape(len(X_test),-1)] from sklearn.manifold import TSNE model = TSNE(n_components=2) tsned_row = model.fit_transform(data_row)
label = np.r_[["b" for i in range(1000)], ["r" for i in X_test]] plt.figure(figsize=(30,30)) plt.scatter(tsned_row[:,0], tsned_row[:,1], c=label, linewidths=0) plt.show()
なんでや…
ミス訂正(2016/4/26)
FaxOCRのほうが0~1に補正されていなかった。正しいt-SNEの結果は以下。 これなら、線が細いやつが固まっていると思えば、まだありえる(?)
一旦終わり
正直なにがダメなのかよくわかってない。というかバグなんじゃないのか、僕のコードがなにかやらかしてるんじゃないのか。 見た目おんなじに見えるんだけど、何か本当に違うのか… 誰かバグとか根本的な間違いとか見つけたら教えてください…
ミス訂正(2016/4/26)
コードにミスがあったので、元画像データはmnistの一部がある空間にありそうだとわかった。時間があるときにもうちょっとちゃんとやって追記します。
続き
参考
- テストデータ (MNIST IDX形式) - Shinsai FaxOCR
- Chainerによる畳み込みニューラルネットワークの実装 - 人工知能に関する断創録Chainerによる畳み込みニューラルネットワークの実装 - 人工知能に関する断創録
- t-SNE を用いた次元圧縮方法のご紹介 | ALBERT Official Blogt-SNE を用いた次元圧縮方法のご紹介 | ALBERT Official Blog
- Digit Recognizer | Kaggle
- FaxOCR を CNN でやってみた | ZABURO app
- kaggle-digit-recognizer/2_model.lua at master · toshi-k/kaggle-digit-recognizer · GitHub