はじめに
「情報幾何学の新展開」という本を読んでいる. まだ序盤しか読めてない上に,あまり理解できていないが,自分の理解のために,例として1次元ガウス分布を対象として,以下の導出とプロットをやる.
- 指数型分布族の標準形および双対座標系
- ポテンシャル関数,双対ポテンシャル関数
- 測地線,双対測地線
かなり天下り的にやっている部分が多いので,主にプロットしただけという感もある.
import numpy as np import pylab as plt from numpy import pi as π from numpy import log from numpy import exp from mpl_toolkits.mplot3d import Axes3D e = exp(1) %matplotlib inline
座標系のプロット
指数型分布の標準形は以下のようなもの. 1次元ガウス分布は であるので,これをごちゃごちゃいじって,
とすると,ごちゃごちゃいじって,
よって, となる.
ここで,をプロットすると,以下のように,これが凸関数っぽくなっていることがわかる.(実際にも凸関数である)
def p2θ(μ, σ2): θ1 = μ/σ2 θ2 = 1/(2*σ2) return θ1, θ2 def θ2p(θ1, θ2): μ = θ1/(2*θ2) σ2 = 1/(2*θ2) return μ, σ2 def ψ(θ1, θ2): return (θ1)**2/ (4*θ2) + 1/2 * log(π) + 1/2*log(θ2)
def plot_lattice(t1, t2, c, alpha=1): for t1_i, t2_i in zip(t1, t2): plt.plot(t1_i, t2_i, c=c, alpha=alpha) for t1_i, t2_i in zip(t1.transpose(), t2.transpose()): plt.plot(t1_i, t2_i, c=c, alpha=alpha)
temp_θ1 = np.linspace(-1,1, 30) temp_θ2 = np.linspace(1, 3, 30) θ1, θ2 = np.meshgrid(temp_θ1, temp_θ2) ψ_θ = ψ(θ1, θ2) μ, σ2 = θ2p(θ1,θ2)
fig = plt.figure() ax = fig.add_subplot(111, projection="3d") ax.plot_wireframe(θ1, θ2, ψ_θ) plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") ax.set_zlabel("$\psi(\\theta)$") plt.show()
%matplotlib inline plt.figure(figsize=(8,4)) plt.subplot(121) plot_lattice(μ, σ2, "C0", alpha=0.3) cont = plt.contour(μ, σ2, ψ_θ) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(122) plot_lattice(θ1, θ2, "C1", alpha=0.3) cont = plt.contour(θ1, θ2, ψ_θ) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.tight_layout() plt.show()
ここでの微分を考えると,
凸関数の微分は元の座標系に対して1つ求まり,また,同じものはない. このため,を座標系として使うこともできる.これを双対座標という. これをプロットすると以下.
def θ2η(θ1, θ2): η1 = θ1/(2*θ2) η2 = -θ1**2/(2*θ2)**2 - 1/(2*θ2) return η1, η2 def η2θ(η1, η2 ): θ1 = -η1/(η2+η1**2) θ2 = -1/2 * 1/(η2+η1**2) return θ1, θ2
μ, σ2 = θ2p(θ1,θ2) η1,η2 = θ2η(θ1, θ2)
plt.figure(figsize=(12,4)) plt.subplot(131) plot_lattice(μ, σ2, "C0") plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(132) plot_lattice(θ1, θ2, "C1") plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.subplot(133) plot_lattice(η1, η2, "C2") plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") plt.tight_layout() plt.show()
この座標系でもをプロットすると以下.
μ, σ2 = θ2p(θ1,θ2) η1,η2 = θ2η(θ1, θ2) plt.figure(figsize=(12,4)) plt.subplot(131) plot_lattice(μ, σ2, "C0", alpha=0.3) cont = plt.contour(μ, σ2, ψ_θ) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(132) plot_lattice(θ1, θ2, "C1", alpha=0.3) cont = plt.contour(θ1, θ2, ψ_θ) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.subplot(133) plot_lattice(η1, η2, "C2", alpha=0.3) cont = plt.contour(η1, η2, ψ_θ) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") plt.tight_layout() plt.show()
ここで,双対ポテンシャル関数は で与えられるので(?), これをプロットして以下.
これも双対ポテンシャル関数も凸関数であるので,その微分を座標系にすることもできる. これを求めると,元の座標系になる.
def ϕ(η1, η2): return -1/2*log(2*π*e) - 1/2 * log(-(η1**2 + η2))
temp_η1 = np.linspace(-0.4, 0.4, 30) temp_η2 = np.linspace(-1, -0.2, 30) η1, η2 = np.meshgrid(temp_η1, temp_η2) θ1,θ2 = η2θ(η1, η2) μ, σ2 = θ2p(θ1,θ2) ϕ_η = ϕ(η1, η2)
fig = plt.figure() ax = fig.add_subplot(111, projection="3d") ax.plot_wireframe(η1, η2, ϕ_η) plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") ax.set_zlabel("$\phi(\eta)$") plt.show()
%matplotlib inline plt.figure(figsize=(12,4)) plt.subplot(131) plot_lattice(μ, σ2, "C0", alpha=0.3) cont = plt.contour(μ, σ2, ϕ_η) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(132) plot_lattice(θ1, θ2, "C1", alpha=0.3) cont = plt.contour(θ1, θ2, ϕ_η) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.subplot(133) plot_lattice(η1, η2, "C2", alpha=0.3) cont = plt.contour(η1, η2, ϕ_η) cont.clabel(fmt='%1.1f', fontsize=14) plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") plt.tight_layout() plt.show()
測地線と双対測地線
測地線と双対測地線は,それぞれの座標系で直線として与えられる. つまり測地線は , 双対測地線は , である.
対応する点 とについて,それぞれの測地線をプロットすると以下.
θ_1 = np.array([-1, 1]) θ_2 = np.array([1, 3]) list_t = np.arange(0, 1, 0.01) line_θ = np.transpose([θ_1*t + θ_2*(1-t) for t in list_t]) line_p1 = θ2p(*line_θ) line_η = θ2η(*line_θ) temp_θ1 = np.linspace(-1,1, 30) temp_θ2 = np.linspace(1, 3, 30) θ1, θ2 = np.meshgrid(temp_θ1, temp_θ2) μ, σ2 = θ2p(θ1, θ2) η1,η2 = θ2η(θ1,θ2) plt.figure(figsize=(12,4)) plt.subplot(131) plot_lattice(μ, σ2, "C0", alpha=0.3) plt.plot(line_p1[0], line_p1[1], "C4") plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(132) plot_lattice(θ1, θ2, "C1", alpha=0.3) plt.plot(line_θ[0], line_θ[1], "C4") plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.subplot(133) plot_lattice(η1, η2, "C2", alpha=0.3) plt.plot(line_η[0], line_η[1], "C4") plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") plt.tight_layout() plt.show()
η_1 = np.array([-0.5, -3/4]) η_2 = np.array([ 1/6, -7/36]) list_t = np.arange(0, 1, 0.01) line_η= np.transpose([η_1*t + η_2*(1-t) for t in list_t]) line_θ= η2θ(*line_η) line_p2 = θ2p(*line_θ) temp_η1 = np.linspace(-0.5, 0.5, 30) temp_η2 = np.linspace(-1, -0.2, 30) η1, η2 = np.meshgrid(temp_η1, temp_η2) θ1,θ2 = η2θ(η1,η2) μ, σ2 = θ2p(θ1, θ2) plt.figure(figsize=(12,4)) plt.subplot(131) plot_lattice(μ, σ2, "C0", alpha=0.3) plt.plot(line_p2[0], line_p2[1], "C3") plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.subplot(132) plot_lattice(θ1, θ2, "C1", alpha=0.3) plt.plot(line_θ[0], line_θ[1], "C3") plt.xlabel("$\\theta^{(1)}$") plt.ylabel("$\\theta^{(2)}$") plt.subplot(133) plot_lattice(η1, η2, "C2", alpha=0.3) plt.plot(line_η[0], line_η[1], "C3") plt.xlabel("$\eta^{(1)}$") plt.ylabel("$\eta^{(2)}$") plt.tight_layout() plt.show()
これらを元のパラメータ空間でプロットすると以下のようになる.
temp_μ = np.linspace(-0.5,0.5, 30) temp_σ2 = np.linspace(0.01, 1.0, 30) μ, σ2 = np.meshgrid(temp_μ, temp_σ2) plot_lattice(μ, σ2, "C0", alpha=0.3) plt.plot(line_p1[0], line_p1[1], "C4") plt.plot(line_p2[0], line_p2[1], "C3") plt.text(-0.2, 0.5,"Dual geodesic line", size=14, color="C3") plt.text(-0.4, 0.2,"Geodesic line", size=14, color="C4") plt.xlabel("$\mu$") plt.ylabel("$\sigma^2$") plt.tight_layout() plt.show()
おわりに
とりあえず双対測地線のプロットまでをやった. 理解できてない部分もあるので,何か間違っているかもしれない.
この本は個人的には読みやすく感じるが,数学に詳しい人に聞くと, 先に接続とか平坦とかをやってから,測地線の方程式を求めたりしたほうがいいらしい (微分幾何を専攻したい人生だった). 実際天下り的に感じる部分もあるので,もうちょっと読めたら,ちゃんとそっちの順番で理解していきたいと思う.