1次元ガウス分布の測地線と双対測地線のプロット

はじめに

情報幾何学の新展開」という本を読んでいる. まだ序盤しか読めてない上に,あまり理解できていないが,自分の理解のために,例として1次元ガウス分布を対象として,以下の導出とプロットをやる.

  • 指数型分布族の標準形および双対座標系
  • ポテンシャル関数,双対ポテンシャル関数
  • 測地線,双対測地線

かなり天下り的にやっている部分が多いので,主にプロットしただけという感もある.

import numpy as np
import pylab as plt
from numpy import pi as π
from numpy import log
from numpy import exp
from mpl_toolkits.mplot3d import Axes3D
e = exp(1)
%matplotlib inline

座標系のプロット

指数型分布の標準形は以下のようなもの.  p({\bf x, \theta}) = \exp\left(\sum_i \theta^{(i)} x_i - \psi({\bf \theta})\right) 1次元ガウス分布 p(x, \mu , \sigma) = \frac{1}{\sqrt{2 \pi \sigma^2} } \exp \left( \frac{(x-\mu)^2}{2 \sigma^2} \right) であるので,これをごちゃごちゃいじって,

 \theta^{(1)} = \frac{\mu}{\sigma^2}, \theta^{(2)} = \frac{1}{2\sigma^2}, x_{(1)}=x, x_{(2)}=-x^2とすると,ごちゃごちゃいじって,

 p(x, \theta) = \exp(\sum_i \theta^{(i)} x_i - \psi({\bf \theta})) - \left( \frac{(\theta^{(1)})^2}{4\theta^{(2)}} + \frac{1}{2}\log \pi -\frac{1}{2} \log \theta^{(2)}\right)

よって,  \psi(\theta) = \frac{(\theta^{(1)})^2}{4\theta^{(2)}} + \frac{1}{2}\log \pi -\frac{1}{2} \log \theta^{(2)} となる.

ここで, \psi(\theta)をプロットすると,以下のように,これが凸関数っぽくなっていることがわかる.(実際にも凸関数である)

def p2θ(μ, σ2):
    θ1 = μ/σ2
    θ2 = 1/(2*σ2)
    return θ1, θ2

def θ2p(θ1, θ2):
    μ = θ1/(2*θ2)
    σ2 = 1/(2*θ2)
    return μ, σ2

def ψ(θ1, θ2):
    return (θ1)**2/ (4*θ2) + 1/2 * log(π) + 1/2*log(θ2) 
def plot_lattice(t1, t2, c, alpha=1):
    for t1_i, t2_i in zip(t1, t2):
        plt.plot(t1_i, t2_i, c=c, alpha=alpha)
    for t1_i, t2_i in zip(t1.transpose(), t2.transpose()):
        plt.plot(t1_i, t2_i, c=c, alpha=alpha)
temp_θ1 = np.linspace(-1,1, 30)
temp_θ2 = np.linspace(1, 3, 30)

θ1, θ2 = np.meshgrid(temp_θ1, temp_θ2)
ψ_θ = ψ(θ1, θ2)
μ, σ2 = θ2p(θ1,θ2)
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_wireframe(θ1, θ2, ψ_θ)
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")
ax.set_zlabel("$\psi(\\theta)$")
plt.show()

f:id:ksknw:20180814095351p:plain

%matplotlib inline
plt.figure(figsize=(8,4))

plt.subplot(121)
plot_lattice(μ, σ2, "C0", alpha=0.3)
cont = plt.contour(μ, σ2, ψ_θ)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")

plt.subplot(122)
plot_lattice(θ1, θ2, "C1", alpha=0.3)
cont = plt.contour(θ1, θ2, ψ_θ)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095418p:plain

ここで \psi(\theta)微分 {\bf \eta} = \nabla \psi (\theta)を考えると,  \eta^{(1)} = \frac{\theta^{(1)}}{2\theta^{(2)}}, \eta^{(2)} = -\frac{(\theta^{(1)})^2}{(\theta^{(2)})^2}-\frac{1}{2\theta^{(2)}}

凸関数の微分は元の座標系に対して1つ求まり,また,同じものはない. このため, \etaを座標系として使うこともできる.これを双対座標という. これをプロットすると以下.

def θ2η(θ1, θ2):
    η1 = θ1/(2*θ2)
    η2 = -θ1**2/(2*θ2)**2 - 1/(2*θ2)
    return η1, η2

def η2θ(η1, η2 ):
    θ1 = -η1/(η2+η1**2)
    θ2 = -1/2 * 1/(η2+η1**2)
    return θ1, θ2
μ, σ2 = θ2p(θ1,θ2)
η1,η2 = θ2η(θ1, θ2)
plt.figure(figsize=(12,4))
plt.subplot(131)
plot_lattice(μ, σ2, "C0")
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")
plt.subplot(132)
plot_lattice(θ1, θ2, "C1")
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")
plt.subplot(133)
plot_lattice(η1, η2, "C2")
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095434p:plain

この座標系でも \psi(\theta)をプロットすると以下.

μ, σ2 = θ2p(θ1,θ2)
η1,η2 = θ2η(θ1, θ2)

plt.figure(figsize=(12,4))

plt.subplot(131)
plot_lattice(μ, σ2, "C0", alpha=0.3)
cont = plt.contour(μ, σ2, ψ_θ)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")

plt.subplot(132)
plot_lattice(θ1, θ2, "C1", alpha=0.3)
cont = plt.contour(θ1, θ2, ψ_θ)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")

plt.subplot(133)
plot_lattice(η1, η2, "C2", alpha=0.3)
cont = plt.contour(η1, η2, ψ_θ)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095446p:plain

ここで,双対ポテンシャル関数は  \phi(\eta) = \theta \cdot \eta - \psi \left( \theta(\eta)\right )で与えられるので(?),  \phi(\eta) = -\frac{1}{2} \log(2\pi e) -\frac{1}{2} \log{ -(\eta^{(1)})^2 + \eta^{(2)}} これをプロットして以下.

これも双対ポテンシャル関数も凸関数であるので,その微分を座標系にすることもできる. これを求めると,元の \theta座標系になる.

def ϕ(η1, η2):
    return -1/2*log(2*π*e) - 1/2 * log(-(η1**2 + η2))
temp_η1 = np.linspace(-0.4, 0.4, 30)
temp_η2 = np.linspace(-1, -0.2, 30)

η1, η2 = np.meshgrid(temp_η1, temp_η2)


θ1,θ2 = η2θ(η1, η2)
μ, σ2 = θ2p(θ1,θ2)

ϕ_η = ϕ(η1, η2)
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.plot_wireframe(η1, η2, ϕ_η)
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
ax.set_zlabel("$\phi(\eta)$")
plt.show()

f:id:ksknw:20180814095501p:plain

%matplotlib inline

plt.figure(figsize=(12,4))

plt.subplot(131)
plot_lattice(μ, σ2, "C0", alpha=0.3)
cont = plt.contour(μ, σ2, ϕ_η)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")

plt.subplot(132)
plot_lattice(θ1, θ2, "C1", alpha=0.3)
cont = plt.contour(θ1, θ2, ϕ_η)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")

plt.subplot(133)
plot_lattice(η1, η2, "C2", alpha=0.3)
cont = plt.contour(η1, η2, ϕ_η)
cont.clabel(fmt='%1.1f', fontsize=14)
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095517p:plain

測地線と双対測地線

測地線と双対測地線は,それぞれの座標系で直線として与えられる. つまり測地線は  \theta(t) = t \theta_1 + (1-t)\theta_2, 双対測地線は  \eta(t) = t \eta_1 + (1-t)\eta_2, である.

対応する点  \theta \etaについて,それぞれの測地線をプロットすると以下.

θ_1 = np.array([-1, 1])
θ_2 = np.array([1, 3])

list_t = np.arange(0, 1, 0.01)
line_θ = np.transpose([θ_1*t + θ_2*(1-t) for t in list_t])
line_p1 = θ2p(*line_θ)
line_η = θ2η(*line_θ)

temp_θ1 = np.linspace(-1,1, 30)
temp_θ2 = np.linspace(1, 3, 30)

θ1, θ2 = np.meshgrid(temp_θ1, temp_θ2)
μ, σ2 = θ2p(θ1, θ2)
η1,η2 = θ2η(θ1,θ2)


plt.figure(figsize=(12,4))

plt.subplot(131)
plot_lattice(μ, σ2, "C0", alpha=0.3)
plt.plot(line_p1[0], line_p1[1], "C4")
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")

plt.subplot(132)
plot_lattice(θ1, θ2, "C1", alpha=0.3)
plt.plot(line_θ[0], line_θ[1], "C4")
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")

plt.subplot(133)
plot_lattice(η1, η2, "C2", alpha=0.3)
plt.plot(line_η[0], line_η[1], "C4")
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095531p:plain

η_1 = np.array([-0.5, -3/4])
η_2 = np.array([ 1/6, -7/36])

list_t = np.arange(0, 1, 0.01)
line_η= np.transpose([η_1*t + η_2*(1-t) for t in list_t])
line_θ= η2θ(*line_η)
line_p2 = θ2p(*line_θ)
    
temp_η1 = np.linspace(-0.5, 0.5, 30)
temp_η2 = np.linspace(-1, -0.2, 30)

η1, η2 = np.meshgrid(temp_η1, temp_η2)
θ1,θ2 = η2θ(η1,η2)
μ, σ2 = θ2p(θ1, θ2)

    
plt.figure(figsize=(12,4))

plt.subplot(131)
plot_lattice(μ, σ2, "C0", alpha=0.3)
plt.plot(line_p2[0], line_p2[1], "C3")
plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")

plt.subplot(132)
plot_lattice(θ1, θ2, "C1", alpha=0.3)
plt.plot(line_θ[0], line_θ[1], "C3")
plt.xlabel("$\\theta^{(1)}$")
plt.ylabel("$\\theta^{(2)}$")

plt.subplot(133)
plot_lattice(η1, η2, "C2", alpha=0.3)
plt.plot(line_η[0], line_η[1], "C3")
plt.xlabel("$\eta^{(1)}$")
plt.ylabel("$\eta^{(2)}$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095542p:plain

これらを元のパラメータ空間でプロットすると以下のようになる.

temp_μ = np.linspace(-0.5,0.5, 30)
temp_σ2 = np.linspace(0.01, 1.0, 30)

μ, σ2 = np.meshgrid(temp_μ, temp_σ2)

plot_lattice(μ, σ2, "C0", alpha=0.3)
plt.plot(line_p1[0], line_p1[1], "C4")
plt.plot(line_p2[0], line_p2[1], "C3")

plt.text(-0.2, 0.5,"Dual geodesic line", size=14, color="C3")
plt.text(-0.4, 0.2,"Geodesic line", size=14, color="C4")

plt.xlabel("$\mu$")
plt.ylabel("$\sigma^2$")
plt.tight_layout()
plt.show()

f:id:ksknw:20180814095554p:plain

おわりに

とりあえず双対測地線のプロットまでをやった. 理解できてない部分もあるので,何か間違っているかもしれない.

この本は個人的には読みやすく感じるが,数学に詳しい人に聞くと, 先に接続とか平坦とかをやってから,測地線の方程式を求めたりしたほうがいいらしい (微分幾何を専攻したい人生だった). 実際天下り的に感じる部分もあるので,もうちょっと読めたら,ちゃんとそっちの順番で理解していきたいと思う.

参考

情報幾何学の新展開