pythonでNUTSの実装

はじめに

前回 ハミルトニアンモンテカルロ法の実装をやった.

今回は No U-Turn Sampler (NUTS)の実装をやる. 論文を参考にした.

コードはここにもある

github.com

NUTS

ハミルトニアンモンテカルロ (HMC)はパラメータの勾配を利用して, 効率的にMCMCサンプリングを行うことができる手法だった.

HMCの問題点は2つ.

  • 更新ステップ数 { \displaystyle L}を適切に指定しなければいけない.
  • 更新の大きさ { \displaystyle \epsilon} を適切に指定しなければいけない.

NUTSは更新ステップ数{ \displaystyle L}を自動的に決定する手法である. 論文内では{ \displaystyle \epsilon}はdual-averaging (Nesterov 2009)を用いて決定するが,今回は決め打ちにする.

更新ステップ数L

ハミルトニアンモンテカルロでは,正規分布によって発生させた運動量を与えて, { \displaystyle L}ステップの間,点を動かす. 予め決められた{ \displaystyle L}ステップの間,点を動かすので,例えば谷にハマった時などガタガタして無駄な計算をしてしまう.

NUTSでは「Uターンしたら点を動かすのをやめる」という規則でこの無駄な計算をなくす. ただし,単純にUターンしたときに中断したら詳細釣り合い条件を満たさなくなるので,少し工夫する.

具体的には,だいたい

  1. 運動量{ \displaystyle r_0}をランダムに決める.
  2. 時間の向き{ \displaystyle v}を{-1, 1}からランダムに選ぶ.
  3. 選んだ時間の向きの方向に{ \displaystyle 2^j}回,点を移動させる(点の軌跡は全て記憶しておく)
  4. どこかでUターンしていないかを確認する
  5. Uターンしていたら6へ,それ以外は{ \displaystyle j+1}して2へ戻る
  6. これまでに計算した軌跡からランダムに1点選んで,サンプリング結果に加えて,1に戻る

という感じ.

実装

以下のようなデータの平均と分散をサンプリングする.

import numpy as np
from numpy import exp
from copy import deepcopy
import pylab as plt
import seaborn as sns
from scipy.stats import norm, gamma
from tqdm import tqdm
from numpy import random
from matplotlib.animation import FuncAnimation

sns.set_style("white")

true_μ = 3
true_σ = 1
nb_data = 1000

x = np.random.normal(true_μ, true_σ, nb_data)

print(x.mean(), x.std())
sns.kdeplot(x)
plt.show()
2.99882646635 1.02180321215

f:id:ksknw:20170806182124p:plain

NUTSのプログラムは以下. 論文のNaive-NUTSの実装をする. Leapfrogで点を移動させるなどはハミルトニアンモンテカルロと同じ.

def log_dh(μ, σ):
    return np.array([-np.sum(x - μ) / σ**2,
                     len(x) / (2 * σ**2) - np.sum((x - μ)**2) / (2 * σ**4)])

def H(θₜ, p):
    return -norm_lpdf(θₜ[0], θₜ[1]) + 0.5 * np.dot(p, p)

def Leapfrog(x, θ, r, ε):
    θ_d = deepcopy(θ)
    r_d = deepcopy(r)
    r_d -= 0.5 * ε * log_dh(θ_d[0], θ_d[1])
    θ_d[0] = θ_d[0] + ε * r_d[0]
    θ_d[1] = θ_d[1] + ε * r_d[1]
    r_d -= 0.5 * ε * log_dh(θ_d[0], θ_d[1])
    return θ_d, r_d
norm_lpdf = lambda μ, σ: np.sum(norm.logpdf(x, μ, σ))
gamma_lpdf = lambda a: np.sum(gamma.logpdf(x, a))

Δ_max = 1000
ε = 0.05
L = norm_lpdf
M = 100

θ0 = np.array([random.randn(), random.gamma(1)])
list_θₘ = [θ0]

{ \displaystyle 2^j} 回の移動は再帰で実装される.

def BuildTree(θ, r, u, v, j, ε):
    if j == 0:
        θd, rd = Leapfrog(x, θ, r, v * ε)
        if np.log(u) <= (L(*θd) - 0.5 * np.dot(rd, rd)):
            Cd_ = [[θd, rd]]
        else:
            Cd_ = []
        sd = int(np.log(u) < (Δ_max + L(*θd) - 0.5 * np.dot(rd, rd)))
        return θd, rd, θd, rd, Cd_, sd
    else:
        θ_minus, r_minus, θ_plus, r_plus, Cd_, sd = BuildTree(θ, r, u, v, j - 1, ε)
        if v == -1:
            θ_minus, r_minus, _, _, Cdd_, sdd = BuildTree(θ_minus, r_minus, u, v, j - 1, ε)
        else:
            _, _, θ_plus, r_plus, Cdd_, sdd = BuildTree(θ_plus, r_plus, u, v, j - 1, ε)
        sd = sdd * sd * int((np.dot(θ_plus - θ_minus, r_minus) >= 0) and (np.dot(θ_plus - θ_minus, r_plus) >= 0))
        Cd_.extend(Cdd_)

        return θ_minus, r_minus, θ_plus, r_plus, Cd_, sd
hist_L = []
hist_C = []
for m in tqdm(range(M)):
    r0 = random.randn(2)
    u = random.uniform(0, exp(L(*list_θₘ[-1]) - 0.5 * np.dot(r0, r0)))

    θ_minus = deepcopy(list_θₘ[-1])
    θ_plus = deepcopy(list_θₘ[-1])
    r_minus = deepcopy(r0)
    r_plus = deepcopy(r0)
    j = 0
    C = [[deepcopy(list_θₘ[-1]), deepcopy(r0)]]
    s = 1

    while s == 1:
        v = random.choice([-1, 1])
        if v == -1:
            θ_minus, r_minus, _, _, Cd, sd = BuildTree(θ_minus, r_minus, u, v, j, ε)
        else:
            _, _, θ_plus, r_plus, Cd, sd = BuildTree(θ_plus, r_plus, u, v, j, ε)

        if sd == 1:
            C.extend(Cd)
        s = sd * int((np.dot(θ_plus - θ_minus, r_minus) >= 0) and (np.dot(θ_plus - θ_minus, r_plus) >= 0))
        j += 1

    index = random.choice(list(range(len(C))))
    list_θₘ.append(C[index][0])

    hist_L.append(L(C[index][0][0], C[index][0][1]))
    hist_C.append(C)
100%|██████████| 100/100 [00:00<00:00, 760.56it/s]
def plot(list_θₘ, hist_C):
    fig = plt.figure()
    ax = fig.add_subplot(111)

    def update(i):
        fig.canvas.draw()
        ax.cla()
        j = i // 3
        if (i % 3) == 0:
            ax.scatter(np.array(hist_C[j])[:, 0, 0], np.array(hist_C[j])[:, 0, 1], linewidth=0, marker=".")
            ax.plot(list_θₘ[:(j), 0], list_θₘ[:(j), 1], c="gray", linewidth=0.3, alpha=0.4)
            ax.scatter(list_θₘ[:(j), 0], list_θₘ[:(j), 1], c="g", linewidth=0, marker=".")

        elif (i % 3) == 1:
            ax.scatter(np.array(hist_C[j])[:, 0, 0], np.array(hist_C[j])[:, 0, 1], linewidth=0, marker=".")
            ax.plot(list_θₘ[:(j + 1), 0], list_θₘ[:(j + 1), 1], c="gray", linewidth=0.3, alpha=0.4)
            ax.scatter(list_θₘ[:(j + 1), 0], list_θₘ[:(j + 1), 1], c="g", linewidth=0, marker=".")
        else:
            ax.scatter(np.array(hist_C[j])[:, 0, 0], np.array(hist_C[j])[:, 0, 1], linewidth=0, marker=".", c="w")
            ax.plot(list_θₘ[:(j + 1), 0], list_θₘ[:(j + 1), 1], c="gray", linewidth=0.3, alpha=0.4)
            ax.scatter(list_θₘ[:(j + 1), 0], list_θₘ[:(j + 1), 1], c="g", linewidth=0, marker=".")

        plot_lim = 30
        if(j) > plot_lim:
            temp_list = np.array([[np.min(np.array(C)[:, 0], axis=0),
                                   np.max(np.array(C)[:, 0], axis=0)] for C in hist_C[j - plot_lim: j]])

            temp_xlim = [np.min(temp_list[:, :, 0]), np.max(temp_list[:, :, 0])]
            xlim_range = temp_xlim[1] - temp_xlim[0]
            temp_ylim = [np.min(temp_list[:, :, 1]), np.max(temp_list[:, :, 1])]
            ylim_range = temp_ylim[1] - temp_ylim[0]
            plt.xlim(temp_xlim[0] - xlim_range * 0.1, temp_xlim[1] + xlim_range * 0.1)
            plt.ylim(temp_ylim[0] - ylim_range * 0.1, temp_ylim[1] + ylim_range * 0.1)

    ani = FuncAnimation(fig, update, frames=len(hist_C) * 3 - 2)
    ani.save("temp.gif", writer="imagemagick", fps=3)
    
plot(list_θₘ, hist_C)

たぶんできていると思う きれいに書き直すつもりだったけど、面倒だったからやめた

「2017年3月14日をもって動画のアップロード機能は終了いたしました」 じゃないんだよなぁ

参考