FaxOCR手書き数字データの認識 その1

概要

FaxOCRという手書き数字認識の問題に挑戦した。 mnistで学習させたCNNでテストデータを判別すると正答率70%弱と低かった。 FaxOCRのデータだけで学習させたCNNでは96%程度の正答率だった。 mnistのデータとFaxOCRのデータはどうも違うようだけど、何が違うのかよくわからない。

以下はipython notebookの出力をちょこちょこいじったので、変なところがいくつかある。

はじめに

ツイッターを眺めているとこんなツイートを見つけた。

どうもFaxで送られてきた手書き数字を認識したいらしい。 はじめから電子データでいいんちゃうかとか、お役所も色々大変なんだろうなぁとか思いつつ。 手書き数字認識とかCNNに突っ込んだら終わりっしょぐらいの感じで始めた。

FaxOCR

サイトはこちら

バイナリ形式で画像とラベルのセットが用意されている。mnistと同じ形式らしい。 やり始めた当時は学習データが1711画像だった(たぶん)。 mnistが70000枚とかなのを考えても、CNNに入れるにはだいぶ少ないなぁという印象だったので、mnistで学習させたCNNを使ってFaxOCRのテストデータを分類することにした。

mnistで学習させたCNN

mnistでCNNを学習させるというのはTensorFlowのチュートリアルにあるぐらい鉄板なので、ググるとたくさんヒットする。個人的にchainerの使い方なら、なんとなく習得しているので、こちら のchainerの実装を使わせてもらうことにした。 CNNやらのコードは完全にコピペなのでここには書かないが、エポック数だけ20から50に変更した。

学習結果はtest_accuracy=0.694779とむっちゃ低かった。 以下はエポック毎のaccuracy(青)と誤差関数(緑)。学習は進んでいるのに、accuracyが上がっていない。

%matplotlib inline
import pylab as plt
import pandas as pd

accuracy = pd.read_csv("./accuracy_mnist2fax.txt", sep="\t")
loss = pd.read_csv("./loss_mnist2fax.txt", sep="\t")
fig = plt.figure(figsize=[10,10])

accuracy["test_accuracy"].plot()
plt.ylim([0,1])
loss["train_loss"].plot(secondary_y=True)
plt.title("mnist->faxor")
plt.show()

f:id:ksknw:20160424232539p:plain

mnistのデータをテストデータにすると以下のようになる。 lossは同じぐらいなのに、accuracyは1エポックの時点で0.98を超えている。

accuracy = pd.read_csv("./accuracy_mnist2mnist.txt", sep="\t")
loss = pd.read_csv("./loss_mnist2mnist.txt", sep="\t")
fig = plt.figure(figsize=[10,10])
accuracy["test_accuracy"].plot()
plt.ylim([0,1])#plt.ylim([0,1])
loss["train_loss"].plot(secondary_y=True)
plt.title("mnist->mnist")
plt.show()

f:id:ksknw:20160424232639p:plain

一応、用意された学習データで学習すると以下のようになって、やっぱりデータ足りてない感がある。

accuracy = pd.read_csv("./accuracy_fax2fax.txt", sep="\t")
loss = pd.read_csv("./loss_fax2fax.txt", sep="\t")
fig = plt.figure(figsize=[10,10])
accuracy["test_accuracy"].plot()
plt.ylim([0,1])
loss["train_loss"].plot(secondary_y=True)
plt.title("faxor->faxor")
plt.show()

f:id:ksknw:20160424232651p:plain

データを眺める

手書き数字といえばmnistでおっけーというイメージだったので、ちょっとショックだった。 精度が出ていない原因として、データが質的に異なっているという可能性があるので、色々と可視化してみることにした。

画像データ

何はともあれ学習データの画像を見る。以下は可視化のコード。全部表示すると多すぎるので、適当に10件ずつだけ。

%matplotlib inline
from read_data import read,show
from sklearn.datasets import fetch_mldata
import numpy as np

mnist = fetch_mldata('MNIST original', data_home=".")
X = mnist.data
y = mnist.target
X = X.astype(np.float32)
y = y.astype(np.int32)

X /= X.max()
X_train = X
y_train = y

data = read()

test_data = read(dataset="testing")
X_test = test_data[0]
y_test = test_data[1]
X_test = X_test.reshape(len(y_test), -1) 
X_test = X_test / float(X_test.max())

import random
indecies = random.sample(range(len(X_train)), 1000)

for i in range(10):
    show(X_test[i].reshape(28,28))
    
print "###################################################"
print "############## mnist ここから######################"
print "###################################################"
for i in range(10):
    show(X_train[indecies[i]].reshape(28,28))

f:id:ksknw:20160424232712p:plain

f:id:ksknw:20160424232715p:plain

f:id:ksknw:20160424232728p:plain

f:id:ksknw:20160424232743p:plain

f:id:ksknw:20160424232753p:plain

f:id:ksknw:20160424232803p:plain

f:id:ksknw:20160424232808p:plain

f:id:ksknw:20160424232813p:plain

f:id:ksknw:20160424232819p:plain

f:id:ksknw:20160424232827p:plain


mnist ここから


f:id:ksknw:20160424232847p:plain

f:id:ksknw:20160424232852p:plain

f:id:ksknw:20160424232903p:plain

f:id:ksknw:20160424232932p:plain

f:id:ksknw:20160424232927p:plain

f:id:ksknw:20160424232908p:plain

f:id:ksknw:20160424232912p:plain

f:id:ksknw:20160424232918p:plain

f:id:ksknw:20160424232921p:plain

f:id:ksknw:20160424232924p:plain

ぱっと見た感じ、mnistのほうが太い線で書かれたものが多い気がする。とはいえ、細い線のものもあるし、認識してくれてもいいんじゃないかという印象。

t-SNEによる可視化

なんかよくわからんときは、とりあえずt-SNEにぶちこむというのが最近のマイブーム。 このあたりが詳しい。 これとかをみると、PCAよりええんちゃうって思う。 scikit-learnに関数があるので、使うのはとても簡単。mnist全データを突っ込むとメモリが足りないと怒られたので、ランダムに1000点選んで描画した。

data = np.r_[X_train[indecies], X_test]

from sklearn.manifold import TSNE
model = TSNE(n_components=2)

tsned = model.fit_transform(data)
import pylab as plt
label = np.r_[["b" for i in X_train[:1000]], ["r" for i in X_test]]
plt.figure(figsize=(30,30))
plt.scatter(tsned[:,0], tsned[:,1], c=label, linewidths=0)
plt.show()

f:id:ksknw:20160424232933p:plain

青(mnist)と赤(FaxOCR)のデータが明らかに分離している。こりゃあかんわって感じ。

データを増やしてCNN

そうこうしているうちに、こちらに精度を抜かれてしまっていた。

"適当に拡大縮小や回転をして画像データの枚数を11倍に(1枚から10枚生成)しました。"

とあって、すごく妥当だと思うし、なんで自分はこんなわけわからんことやってんだろうと思う。 とはいえ、精度で負けてるのはなんか悔しいので48x48のデータを51倍にしてCNNに突っ込んだ。

一応画像を適当に増やすところのコードは以下。

size = tuple(np.array([X_train[0].shape[1], X_train[0].shape[0]]))

new_imgs = []
new_labels = []
import random
import cv2
from itertools import izip
for img, label in izip(X_train, y_train):
    for i in range(50):
        rad = (random.random() - 0.5) * 0.5
        pos1 = (random.random() - 0.5) * 5
        pos2 = (random.random() - 0.5) * 5
        mat = np.float32([[np.cos(rad), -1 * np.sin(rad), pos1],
                          [np.sin(rad), np.cos(rad), pos2]])
        dst = cv2.warpAffine(img, mat, size, flags=cv2.INTER_LINEAR)
        new_imgs.append(dst)
        new_labels.append(label)
X_train = np.r_[X_train, new_imgs]
y_train = np.r_[y_train, new_labels]

accuracy(青)と誤差関数(緑)は以下。とりあえず96%とかいっているのでまあまあというところ。

accuracy = pd.read_csv("./accuracy_fax2fax_copied_48.txt", sep="\t")
loss = pd.read_csv("./loss_fax2fax_copied_48.txt", sep="\t")
fig = plt.figure(figsize=[10,10])
accuracy["test_accuracy"].plot()
plt.ylim([0,1])
loss["train_loss"].plot(secondary_y=True)
plt.title("faxor48->faxor48")
plt.show()

f:id:ksknw:20160424234032p:plain

ちなみに間違っていた画像は以下。これはしょうがないんじゃないかと思うものが多い。(というか学習データのほうは大丈夫なんだろうか…)

%matplotlib inline

import cPickle as pickle
import numpy as np
import chainer
from chainer import cuda
import chainer.functions as F

xp=np


def forward(x_data, y_data):
    x, t = chainer.Variable(x_data), chainer.Variable(y_data)
    h = F.max_pooling_2d(F.relu(model.conv1(x)), 2)
    h = F.max_pooling_2d(F.relu(model.conv2(h)), 2)
    h = F.dropout(F.relu(model.l1(h)), train=False)
    y = model.l2(h)

    return y, t,F.accuracy(y,t)
    
with open("model_cnn_48.pkl", 'rb') as i:
    model = pickle.load(i)
    
from read_data import read, show
test_data = read(dataset="testing", size=48)
X_test = test_data[0].astype(xp.float32)
y_test = test_data[1].astype(xp.int32)
X_test = X_test.reshape(len(y_test), -1)
X_test = X_test / float(X_test.max())
X_test = X_test.reshape((len(X_test), 1, 48, 48))

y,t,acc = forward(X_test, y_test)
print "#################################"
print "accuracy: " + str(acc.data)
print "#################################"

from itertools import izip

for i,(temp_y,temp_t,temp_X_test) in enumerate(izip(y.data,t.data, X_test)):
    if np.argmax(temp_y)!=temp_t:
        print "No.%d 正解:%d 出力:%d (%s)"%(i,temp_t, np.argmax(temp_y),temp_y)
        show(temp_X_test.reshape(48,48))
    accuracy: 0.967871487141
    No.14 正解:9 出力:8 ([  3.53644633 -66.54345703 -44.21648026 -30.54708862 -47.66264725 -61.4640274  -36.32198715 -54.34354401  28.37440872 14.29025173])

f:id:ksknw:20160424234123p:plain

No.116 正解:9 出力:5 ([-35.68978119 -23.23931885 -76.64511108 -26.64510727 -42.89046097
      50.29698944 -16.02383232 -52.14741135 -22.0790844  -29.76072311])

f:id:ksknw:20160424234135p:plain

    No.127 正解:9 出力:8 ([ -0.36938047 -72.61532593 -38.76506424 -25.25551796 -40.20273972
     -45.48267365 -41.19197464 -39.32862854  26.44895935  11.87366581])

f:id:ksknw:20160424234144p:plain

    No.170 正解:7 出力:1 ([-13.41542912   8.12284565 -28.43426895  -6.35744619 -42.47330093
     -31.51161766 -36.15800858  -0.18305674 -23.2899704  -10.80175686])

f:id:ksknw:20160424234153p:plain

    No.172 正解:9 出力:1 ([-18.31829453   8.41508961 -24.00697708 -18.74674988 -15.38268757
     -19.12284851 -21.66376686 -29.5259037   -9.97081757   5.56778383])

f:id:ksknw:20160424234202p:plain

    No.196 正解:9 出力:4 ([-25.27557564 -20.27404976 -33.58036041 -38.26721573  18.712677
     -17.12530899 -26.8935318  -34.70022964 -20.10196686  11.45114803])

f:id:ksknw:20160424234210p:plain

   No.231 正解:5 出力:8 ([ -3.76606822 -27.21845436 -40.48172379 -41.74909592 -20.90390015
     -27.25065613   8.70675373 -39.89279938  28.23112106 -19.23669815])

f:id:ksknw:20160424234219p:plain

    No.247 正解:2 出力:4 ([-15.3788166   -8.5814476  -14.29022121 -16.61594963   5.68172646
      -8.83567333   0.14922404 -22.43881035 -17.46813393 -14.75987816])

f:id:ksknw:20160424234228p:plain

mnistデータとFaxOCRデータの違い

データを公開された方の本来の目的からすると、パラメータチューニングとかして性能を上げたほうがいいのかもしれないけど、正直そっちにはあまり興味がない。Kaggleガチ勢の方がこういうの とか出してくれているので、参考にするといいのかもしれない。

個人的に気になったのは今回のデータとmnistの違い。 見た目は同じような手書き文字なのに、t-sneで可視化すると明らかに分離している。 Faxのデータはどうも線の細さを統一したり、回転させたりと前処理を結構しているらしいので、それが影響しているのかなと思った。 なので、生データを可視化してみる。FaxOCRの画像サイズはまちまちだったので、ImageMagickで28x28にリサイズした。アスペクト比は保存していないのでややまずい。

import cv2 as cv
import glob
mnist = fetch_mldata('MNIST original', data_home=".")
X = mnist.data
y = mnist.target
X = X.astype(np.float32)
y = y.astype(np.int32)

X /= X.max()
X_train = X
y_train = y

X_test = []
y_test = []
for img_file in glob.glob("./data/mustread/28/*.png"):
    y_test.append(int(img_file[16]))
    X_test.append(255 - cv.imread(img_file, flags=0))
    
data_row = np.r_[X_train[indecies], np.array(X_test).reshape(len(X_test),-1)]
from sklearn.manifold import TSNE
model = TSNE(n_components=2)

tsned_row = model.fit_transform(data_row)
label = np.r_[["b" for i in range(1000)], ["r" for i in X_test]]
plt.figure(figsize=(30,30))
plt.scatter(tsned_row[:,0], tsned_row[:,1], c=label, linewidths=0)
plt.show()

f:id:ksknw:20160424234255p:plain

なんでや…

ミス訂正(2016/4/26)

FaxOCRのほうが0~1に補正されていなかった。正しいt-SNEの結果は以下。 f:id:ksknw:20160426233426p:plain これなら、線が細いやつが固まっていると思えば、まだありえる(?)

一旦終わり

正直なにがダメなのかよくわかってない。というかバグなんじゃないのか、僕のコードがなにかやらかしてるんじゃないのか。 見た目おんなじに見えるんだけど、何か本当に違うのか… 誰かバグとか根本的な間違いとか見つけたら教えてください…

ミス訂正(2016/4/26)

コードにミスがあったので、元画像データはmnistの一部がある空間にありそうだとわかった。時間があるときにもうちょっとちゃんとやって追記します。

続き

ksknw.hatenablog.com

参考