pytorchでgmmの最尤推定

はじめに

今まではKerasを使っていたけど、最近になってpytorchを覚えようとしている。 “Define by Run"と"Define and Run"の違いとかはよくわかっていないのでそのへんは適当。

普通にtutorialだけやっていると、 “なんとかネットワークは作れるけど、自分が考えた新しい層を追加できない” ということになりそうだったので、ネットにあまり情報のなかったgmmを勾配法(最尤推定)で解くプログラムを作って、pytorchを理解することにした。

gaussian mixture model

適当にデータを作る

%matplotlib inline
import pylab as plt
import seaborn as sns
sns.set_style("white")
from scipy.stats import norm

import numpy as np
import torch
from torch.autograd import Variable


K = 2
nb_data = 2000
nb_steps = 2000

true_μ_K = [-3, 3]
true_σ_K = [1, 1]
true_π_K = [0.5, 0.5]

np_data = []
for μₖ, σₖ, πₖ in zip(true_μ_K, true_σ_K, true_π_K):
    for i in range(int(nb_data * πₖ)):
        np_data.append(np.random.normal(μₖ, σₖ))
np_data = np.array(np_data)
def gmm_plot(list_mean, list_std, list_pi, **kwargs):
    x = np.linspace(-10,10,500)
    y = np.sum([norm.pdf(x, mean, np.abs(std))*pi 
                for mean,std,pi in zip(list_mean, list_std, list_pi)], axis=0)
    return plt.plot(x,y, **kwargs)
gmm_plot(true_μ_K, true_σ_K, true_π_K)

f:id:ksknw:20170624233221p:plain

最尤推定

勾配法でパラメータを推定するための誤差関数として、今回は単純な尤度を使った。

gmmの尤度は { \displaystyle
p(x|\theta) = \prod_n  \prod_k \pi_k \mathcal{N}(x | \mu_k, \sigma_k)
} 。これと、πの合計が1になるように、適当に制約を加えて、以下のように誤差関数を定義した。

def get_normal_lpdf(x_N, μ, σ):
    μ_N = μ.expand(x_N.size())
    σ_N = σ.expand(x_N.size())
    return -0.5 * torch.log(2 * np.pi * σ_N ** 2) - 0.5 * (x_N - μ_N)**2 / σ_N ** 2

def get_loss(normal_lpdf_K_N, π_K):
    gmm_lpdf_N = 0
    for normal_lpdfₖ_N, πₖ in zip(normal_lpdf_K_N, π_K):
        πₖ_N = πₖ.expand(normal_lpdfₖ_N.size())
        gmm_lpdf_N += (torch.exp(normal_lpdfₖ_N) * πₖ_N) # TODO logsumexpを実装したほうがいいかも
    gmm_lpdf = torch.mean(torch.log(gmm_lpdf_N))
    
    Σπ = torch.sum(π_K)
    gmm_lpdf -= torch.abs(1 - Σπ) # 制約条件
    return -gmm_lpdf

pytorchではautograd.Variableで変数を定義しておくと、勝手に微分を計算してくれるらしい。 ので、以下のようにデータと求めたいパラメータを定義する。

x_N = Variable(torch.from_numpy(np_data), requires_grad=False).float()
lr = 0.05
μ_K = Variable(torch.randn(K), requires_grad=True)
σ_K = Variable(torch.randn(K)**2, requires_grad=True)
π_K = Variable(torch.abs(torch.randn(K)), requires_grad=True)

あとは以下のように誤差を伝搬させて、パラメータを更新する。 grad.zero_をしないといけないと知らなくて苦労した。

history_loss = []
history_μ_K = []
history_σ_K = []
history_π_K  = []

for i in range(nb_steps):
    normal_lpdf_K_N = []
    for k in range(K):
        normal_lpdf_K_N.append(get_normal_lpdf(x_N, μ_K[k], σ_K[k]))
    loss = get_loss(normal_lpdf_K_N, π_K)
    loss.backward()
    
    μ_K.data -= μ_K.grad.data * lr
    μ_K.grad.data.zero_()
    
    σ_K.data -= σ_K.grad.data * lr
    σ_K.grad.data.zero_()
    
    π_K.data -= π_K.grad.data * lr
    π_K.grad.data.zero_()
    π_K.data = torch.abs(π_K.data)
    
    
    history_loss.extend(loss.data.tolist())
    history_μ_K.append(μ_K.data.tolist())
    history_σ_K.append(σ_K.data.tolist())
    history_π_K.append(π_K.data.tolist())
    

うまく収束してくれた。

gmm_plot(μ_K.data, σ_K.data, π_K.data)
gmm_plot(true_μ_K, true_σ_K, true_π_K)
plt.show()

f:id:ksknw:20170624233252p:plain

plt.plot(history_loss)
plt.show()

f:id:ksknw:20170624233305p:plain

収束のアニメーションをかいてみる

import matplotlib.animation as animation

plts = []
fig = plt.figure()

for μ_K, σ_K, π_K in zip(history_μ_K[::10],
                         history_σ_K[::10],
                         history_π_K[::10]):   
    plts.append(gmm_plot(μ_K, σ_K, π_K, c="b"))

ani = animation.ArtistAnimation(fig, plts, interval=100)
ani.save('anim.gif', writer="imagemagick")
ani.save('anim.mp4', writer="ffmpeg")
plt.show()

f:id:ksknw:20170624233331g:plain

これはうまく収束した結果だけど、何度か実行すると以下のような微妙な結果に収束することも多かった。 (収束したあとに"脈打ってる"のはなんだ…)

f:id:ksknw:20170624233734p:plain

まとめ

pytorchの使い方を覚えるために、gmmをやった。 あまり安定していないけど、動くものができた。 なんとなく使い方はわかってきたけど、"勝手に中身が書き換えられている"という印象をもってしまう部分がある。

やる気があればMAP推定、SGDもしくはHMCに続くかも。