Tree-sliced Wasserstein距離の実装と理解

はじめに

最適輸送の理論とアルゴリズムを買って読みました。 Computational Optimal Transportも7割ぐらい読んでたのですが、やっぱり日本語でこういう本があるといいですね。 本の内容のなかでtree-sliced Wassersteinが気になったので、理解のために実装してみます。 以下では、2つの2次元点群間のtree-sliced Wasserstein距離を考えます。 (証明などを含む細かい内容についてはここでは触れません。また、実装は理解のためのものであり遅いです)

Tree-sliced Wasserstein距離

Wasserstein距離は確率分布などの距離を測れるいい感じの距離ですが、計算が重いという問題があります。 Wasserstein距離を高速に計算するためにいくつかの手法が提案されています。 そのうちの1つがtree-sliced Wasserstein距離です。 Tree-sliced Wasserstein距離は木上の最適輸送距離が木のノード数に対して線形時間で解けることを利用することで高速に計算することができます。 点群データに対するtree-sliced Wasserstein距離を計算するためには、はじめに2つの点群データを1つの木で表現する必要があります。 この変換の方法によって、点群上で直接Wasserstein距離を求める場合とtree-sliced Wasserstein距離は異なる値になります。

はじめに、以下のような4点ずつの点群(X, Y)間のtree-sliced Wasserstein距離を考えます。

import ot
import numpy as np
import pylab as plt
from sklearn.cluster import AgglomerativeClustering
import anytree
from anytree import Node, RenderTree
from anytree.exporter import DotExporter
from IPython.display import Image
X = np.random.randn(4, 2)
Y = np.random.randn(4, 2)

plt.scatter(X[:,0], X[:,1], label="X", marker="o")
plt.scatter(Y[:,0], Y[:,1], label="Y", marker="o")
plt.legend()

Tree-sliced Wasserstein距離では、はじめにこの2つの点群を1つの木に変換します。 木に変換する方法はいくつか考えられますが、ここでは以下のように階層クラスタリングを行います。 今回用いたクラスタリング手法では、データ点はすべて葉ノードに含まれます。

XY = np.r_[X,Y]
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None, metric="l1", linkage="average")
model.fit(XY)

得られたクラスタリング結果を木で表現すると以下のような感じになります。

nodes = []
for i,x_i in enumerate(X):
    nodes.append(Node("x_%d"%i))
for j,y_j in enumerate(Y):
    nodes.append(Node("y_%d"%j))
for node_i, children in enumerate(model.children_):
    nodes.append(Node("%d" % (node_i + len(XY)), children=[nodes[c] for c in children]))
root = nodes[-1]
DotExporter(root).to_picture("temp.png")
Image("temp.png")

この後のためにprintでも木を表示しておきます。

for pre, fill, node in RenderTree(root):
    print("%s(%s)" % (pre, node.name))
(14)
├── (11)
│   ├── (x_2)
│   └── (y_2)
└── (13)
    ├── (8)
    │   ├── (x_1)
    │   └── (y_0)
    └── (12)
        ├── (x_3)
        └── (10)
            ├── (y_3)
            └── (9)
                ├── (x_0)
                └── (y_1)

点群間のWasserstein距離を近似するためには、あるノードから別のノードまでの経路上の距離が対応する2つの点間の距離と近くなる必要があります。 ここでは簡単に、葉ノードの位置はもとの点群の点の位置、それ以外のノードの位置は、そのノードを根とする部分木に含まれる点の位置の平均としました。

nodes = []
for i,x_i in enumerate(X):
    nodes.append(Node("%d"%i, pos=x_i, nb_nodes=1))
for j,y_j in enumerate(Y):
    nodes.append(Node("%d"%j, pos=y_j, nb_nodes=1))
for node_i, children in enumerate(model.children_):
    nb_nodes = sum([nodes[c].nb_nodes for c in children])
    nodes.append(Node("%d" % (node_i + len(XY)), children=[nodes[c] for c in children], 
                      pos=np.sum([nodes[c].pos * nodes[c].nb_nodes for c in children], axis=0)/nb_nodes,
                      nb_nodes=nb_nodes
                     ))
root = nodes[-1]

2次元上にも木をプロットしてみます。緑の点がノードを表しています。

positions = np.array([node.pos for node in nodes])
plt.scatter(positions[:,0], positions[:,1], c="C2", marker=".", label="node")

plt.scatter(X[:,0], X[:,1], label="X", marker="o")
plt.scatter(Y[:,0], Y[:,1], label="Y", marker="o")

for node in nodes:
    for child in node.children:
        plt.plot(*zip(node.pos, child.pos), c="gray", lw=0.4)

plt.legend()

点群を木で表現することができました。次に、木上のWasserstein距離を計算します。 まず、点群の各点に対応するノード(今回は葉ノード)に質量($\mu, \nu$)を与えます。 ここでは同じ質量を各葉ノードに与えることにします。

nodes = []
for i,x_i in enumerate(X):
    nodes.append(Node("%d"%i, a_u=1/len(X), b_u=0, pos=x_i, nb_nodes=1))
for j,y_j in enumerate(Y):
    nodes.append(Node("%d"%j, a_u=0, b_u=1/len(Y), pos=y_j, nb_nodes=1))
for node_i, children in enumerate(model.children_):
    nb_nodes = sum([nodes[c].nb_nodes for c in children])
    nodes.append(Node("%d" % (node_i + len(XY)), children=[nodes[c] for c in children],
                      a_u=0, b_u=0, 
                      pos=np.sum([nodes[c].pos * nodes[c].nb_nodes for c in children], axis=0)/nb_nodes,
                      nb_nodes=nb_nodes
                     ))
root = nodes[-1]
for pre, fill, node in RenderTree(root):
    print("%s(%s) mu:%s nu:%s pos:%s" % (pre, node.name, node.a_u, node.b_u, node.pos))
(14) mu:0 nu:0 pos:[0.05533982 0.2128882 ]
├── (11) mu:0 nu:0 pos:[-1.00601231 -0.97207743]
│   ├── (2) mu:0.25 nu:0 pos:[-1.18749779 -0.59495726]
│   └── (2) mu:0 nu:0.25 pos:[-0.82452682 -1.3491976 ]
└── (13) mu:0 nu:0 pos:[0.40912386 0.60787674]
    ├── (8) mu:0 nu:0 pos:[-0.25630872  0.36544372]
    │   ├── (1) mu:0.25 nu:0 pos:[-0.26986513  0.52426891]
    │   └── (0) mu:0 nu:0.25 pos:[-0.24275231  0.20661853]
    └── (12) mu:0 nu:0 pos:[0.74184015 0.72909325]
        ├── (3) mu:0.25 nu:0 pos:[ 0.65140353 -0.11594999]
        └── (10) mu:0 nu:0 pos:[0.77198569 1.01077434]
            ├── (3) mu:0 nu:0.25 pos:[1.21127603 1.128238  ]
            └── (9) mu:0 nu:0 pos:[0.55234052 0.9520425 ]
                ├── (0) mu:0.25 nu:0 pos:[0.50610384 0.69303867]
                └── (1) mu:0 nu:0.25 pos:[0.5985772  1.21104634]

木$T$とその上の質量(測度)が与えられたとき、木上のWassserstein距離は以下のように求めることができます。 $$ {W}_{d_T} (\mu, \nu) = \sum_{v\in V\setminus {r}} d(v, q(v)) | \mu (\Gamma(v)) - \nu(\Gamma(v)) | $$ ここで、$\Gamma(v)$はノード$v$を根とする部分木を表す。また、$V$はノードの集合、$r$は木の根ノードを表す。 また、$d(v, q(v))$はノード$v$から親ノード$q(v)$までの距離を表しています。

この値はsumをとる順番を工夫することで、$|V|$に対して線形時間で計算することができます。 具体的には根から遠い順に、各ノードの重みを子ノードの重みの合計で更新していきます(本中p206、アルゴリズム5.4)。

depths = [n.depth for n in nodes]
s = 0

for node, depth in sorted(zip(nodes, depths), key=lambda x:x[1], reverse=True):
    node.a_u += sum([n.a_u for n in node.children])
    node.b_u += sum([n.b_u for n in node.children])
    
    if node.parent is not None:
        s += np.sum(np.abs(node.parent.pos - node.pos)) * abs(node.a_u - node.b_u)
print("tree-sliced Wasserstein distance:", s)
tree-sliced Wasserstein distance: 0.9691290075619952

ということで、tree-sliced Wasserstein距離を計算することができました。 各ノードの重みは以下のように更新されています。

for pre, fill, node in RenderTree(root):
    print("%s(%s) mu:%s nu:%s pos:%s" % (pre, node.name, node.a_u, node.b_u, node.pos))
(14) mu:1.0 nu:1.0 pos:[0.05533982 0.2128882 ]
├── (11) mu:0.25 nu:0.25 pos:[-1.00601231 -0.97207743]
│   ├── (2) mu:0.25 nu:0 pos:[-1.18749779 -0.59495726]
│   └── (2) mu:0 nu:0.25 pos:[-0.82452682 -1.3491976 ]
└── (13) mu:0.75 nu:0.75 pos:[0.40912386 0.60787674]
    ├── (8) mu:0.25 nu:0.25 pos:[-0.25630872  0.36544372]
    │   ├── (1) mu:0.25 nu:0 pos:[-0.26986513  0.52426891]
    │   └── (0) mu:0 nu:0.25 pos:[-0.24275231  0.20661853]
    └── (12) mu:0.5 nu:0.5 pos:[0.74184015 0.72909325]
        ├── (3) mu:0.25 nu:0 pos:[ 0.65140353 -0.11594999]
        └── (10) mu:0.25 nu:0.5 pos:[0.77198569 1.01077434]
            ├── (3) mu:0 nu:0.25 pos:[1.21127603 1.128238  ]
            └── (9) mu:0.25 nu:0.25 pos:[0.55234052 0.9520425 ]
                ├── (0) mu:0.25 nu:0 pos:[0.50610384 0.69303867]
                └── (1) mu:0 nu:0.25 pos:[0.5985772  1.21104634]

Wasserstein距離、sliced-Wsserstein距離との比較

tree-sliced Wasserstein距離が求められるようになったので、これを点群間の普通のWasserstein距離、および、同じくWasserstein距離を近似する方法の1つであるsliced-Wasserstein距離と比較します。

def calc_tree_Wass(X, Y):
    XY = np.r_[X,Y]
    model = AgglomerativeClustering(distance_threshold=0, n_clusters=None, metric="l1", linkage="average")
    model.fit(XY)
    
    nodes = []
    for i,x_i in enumerate(X):
        nodes.append(Node("%d"%i, a_u=1/len(X), b_u=0, pos=x_i, nb_nodes=1))
    for j,y_j in enumerate(Y):
        nodes.append(Node("%d"%j, a_u=0, b_u=1/len(Y), pos=y_j, nb_nodes=1))
    for node_i, (children, d) in enumerate(zip(model.children_, model.distances_)):
        nb_nodes = sum([nodes[c].nb_nodes for c in children])
        nodes.append(Node("%d" % (node_i + len(XY)), children=[nodes[c] for c in children],
                          a_u=0, b_u=0, 
                          pos=np.sum([nodes[c].pos * nodes[c].nb_nodes for c in children], axis=0)/nb_nodes,
                          nb_nodes=nb_nodes
                         ))
    root = nodes[-1]

    depths = [n.depth for n in nodes]
    s = 0
    for node, depth in sorted(zip(nodes, depths), key=lambda x:x[1], reverse=True):
        node.a_u += sum([n.a_u for n in node.children])
        node.b_u += sum([n.b_u for n in node.children])
        if node.parent is not None:
            s += np.sum(np.abs(node.parent.pos - node.pos)) * abs(node.a_u - node.b_u)
    return s
def calc_all(X, Y):
    tree_wass = calc_tree_Wass(X,Y)
    
    # 1-Wasserstein        
    M = np.sum(abs(X[:, None] - Y[None]), axis=2)
    wass = ot.emd2(np.ones(len(X))/len(X), 
            np.ones(len(Y))/len(Y),
            M)
    
    # sliced Wasserstein
    sliced_wass = ot.sliced.sliced_wasserstein_distance(X, Y, 
                                         np.ones(len(X))/len(X),
                                         np.ones(len(Y))/len(Y), 
                                         p=1)

    return wass, tree_wass, sliced_wass

適当な点群を2つ作って3つの距離をそれぞれ計算してみます。

distances_bias = []
for bias in np.linspace(0, 5, 10):
    for _ in range(10):
        X = np.random.randn(101, 2)
        Y = np.random.randn(100, 2)  + bias
        
        distances_bias.append(calc_all(X,Y))
distances_bias = np.array(distances_bias)
plt.scatter(X[:,0], X[:,1], label="X", marker=".")
plt.scatter(Y[:,0], Y[:,1], label="Y", marker=".")
plt.legend()

plt.scatter(distances_bias[:,0], distances_bias[:,1], marker=".", c="C2", label="tree-sliced")
plt.scatter(distances_bias[:,0], distances_bias[:,2], marker=".", c="C3", label="sliced")
plt.plot([0,10], [0,10], c="gray")

plt.legend()
plt.xlabel("1-Wasserstein distance")
plt.ylabel("(tree) sliced-Wasserstein distance")

横軸がWasserstein距離、縦軸がその他の2つの距離です。また、グレーの線がWasserstein距離を表しています。 図からわかるように、tree-sliced Wasserstein距離は常にWasserstein距離よりも大きな値になってしまいました。

本には書いてなかったので間違っているかもしれませんが、直感的にはtree-sliced Wasserstein距離とWasserstein距離が一致するためには、2つの葉ノードを結ぶ経路上の距離の合計が、もとの点群の対応する2つの点間の距離と一致する必要があるような気がします。一方で、今回行ったノードに対する座標の与え方では、三角不等式から経路上の距離は常にもとの点間の距離以上の値になってしまいます(いくつか上の世代の親ノードの座標を一旦経由してからもう一方のノードの座標に移動するので)。このため、tree-sliced Wasserstein距離のほうが値が大きくなってしまったのだと思われます。

別の傾向をもった点群に対しても3つの距離を求め、比較してみます。

distances_var = []
for var in np.linspace(0, 5, 10):
    for _ in range(10):
        X = np.random.randn(101, 2)
        Y = np.random.randn(100, 2) * var
        
        distances_var.append(calc_all(X,Y))
distances_var = np.array(distances_var)
plt.scatter(X[:,0], X[:,1], label="X", marker=".")
plt.scatter(Y[:,0], Y[:,1], label="Y", marker=".")
plt.legend()

plt.scatter(distances_bias[:,0], distances_bias[:,1], marker=".", c="C2", label="tree-sliced (bias)")
plt.scatter(distances_var[:,0], distances_var[:,1], marker=".", c="C4", label="tree-sliced (var)")

plt.scatter(distances_bias[:,0], distances_bias[:,2], marker=".", c="C3", label="sliced (bias)")
plt.scatter(distances_var[:,0], distances_var[:,2], marker=".", c="C5", label="sliced (var)")


plt.plot([0,10], [0,10], c="gray")
plt.xlabel("1-Wasserstein distance")
plt.ylabel("tree-sliced Wasserstein distance")
plt.legend()

sliced Wasserstein距離はデータの性質によらず、同じような傾向がありましたが、tree-sliced Wasserstein距離はデータの性質によって距離が変わってしまいました。階層クラスタリングのやり方が良くないことが原因かなと思いますが、今回はここまでです。 本によると、quadtreeを使ってクラスタリングすると近似度に理論保証があるらしいです。

まとめ

最適輸送の理論とアルゴリズムを読む中で気になったtree-sliced Wasserstein距離を理解のために実装してみました。 実装ミスってる可能性もありますが、何やってるかわかった気がします。 木の作り方に距離の性質がかなり依存しているので、別のクラスタリング手法なども試してみたいです

微分可能なソートの実装

はじめに

この論文を読んでおもしろいなと思ったので実装して試してみる。 ここではイントロは飛ばしてやり方の説明だけ。

論文の基本的なアイディアは1次元の最適輸送問題をSinkhornで解いて、得られた輸送計画からソートの結果を推定するという感じ

Preliminary

以下では微分可能なソートに使うSinkhornアルゴリズムと1次元の最適輸送問題についてそれぞれ説明する。

Sinkhornアルゴリズム

2つの点群(離散分布)の間の最適輸送について考える。 簡単のために点の数はともに$n$とし、各点の重みはすべて等しい($1/n$)とする。 このような2つの点群を$X = \{x_1, \ldots, x_n\}$と$Y = \{y_1, \ldots, y_n\}$と書く。 このとき、最適輸送は点同士の対応付けを、対応付けられた点間の距離の合計が最小になるように求める。 式で書くと以下。ただし、ground metric (点同士の距離)として二乗誤差$C$ ($C_{ij} = \|x_i - y_j\|^2_2$)を使う。

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt$$

ここで、$P$は対応付けを表す行列。$U$は対応付けの条件を表しており、$U=\{P\in [0,1]^{n\times n} | P \mathbf 1_n = 1/n, P^\top \mathbf 1_n = 1/n \}$である。例えば$x_1$と$y_2$が対応づけられる場合は$P_{12}=1/n$となる。この最適化問題線形計画問題なので単体法などで解くことができる。

単体法などは微分可能なアルゴリズムではないが、最適化問題を以下のように少し変更することで、微分可能なSinkhorn iterationという方法によって最適輸送問題を解くことができることが知られている。

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt - H(P)$$

ここで,$H(P)$は輸送行列$P$のエントロピーを表し,$H(P) = - \sum_{i,j} P_{ij} \log P_{ij}$である。 また、$\epsilon$はエントロピー正則化の強さを表す。 Sinkhornの具体的な計算については以前書いたものを参考。

ksknw.hatenablog.com

1次元の最適輸送問題

2つの点群がそれぞれ1次元上に存在するときを考える。 見やすさのために$y$軸方向に少しずらしてプロットする。青点が$x_i$、オレンジ点が$y_j$を表す ($x_i, y_j \in \mathbb R$)。

import numpy as np
import pylab as plt
import torch
import torch.nn as nn

torch.manual_seed(1234)

n = 5

X = torch.randn(n)
Y = torch.randn(n)

plt.scatter(X, torch.zeros(n))
plt.scatter(Y, torch.ones(n))
plt.show()

1次元のときの最適輸送問題は簡単に解ける。点のうち左側にあるもの同士を対応付け、次に左から2点目を対応付け、…と左から順に対応付ければ最適解を得ることができる。 これは2つの点群をそれぞれソートする操作に対応する。

P = torch.zeros(n, n)
P[torch.sort(X)[1], torch.sort(Y)[1]] = 1/n

def plot(X, Y, P):
    plt.scatter(X, torch.zeros(n))
    plt.scatter(Y, torch.ones(n))
    for i,p_i in enumerate(P):
        for j,p_ij in enumerate(p_i):
            plt.plot([X[i], Y[j]], [0, 1],
                     alpha=p_ij.numpy()*(n-1),
                     color="gray")
    plt.show()

plot(X, Y, P)

実際に点群をソートするだけで1次元の最適輸送問題を解くことができた。

1次元の最適輸送は多くの場合ソートによって解かれる。これはSinkhornよりもソートのほうが計算量が少なくて済むからである。 しかし、当然Sinkhornでも1次元の最適輸送問題を解くことができる。

class Sinkhorn(nn.Module):
    def __init__(self, epsilon, steps):
        self.epsilon = epsilon
        self.steps = steps
        super(Sinkhorn, self).__init__()

    def softmin(self, A, log_P):
        ret = - self.epsilon * torch.logsumexp(-A / self.epsilon, dim=1)
        return ret

    def forward(self, C, a, b):
        Cmax = C.max()
        C = C/Cmax

        log_a = torch.log(a)
        log_b = torch.log(b)
        log_v = torch.zeros_like(log_b)

        P = a[:, None] @ b[None]
        log_P = torch.log(P)

        previous_cost = 100000
        for step_i in range(self.steps):
            log_u = self.epsilon * log_a + \
                self.softmin(C - log_v.unsqueeze(0), log_P)

            log_v = self.epsilon * log_b + \
                self.softmin(C.transpose(1, 0) - log_u.unsqueeze(0),
                             log_P.transpose(1, 0))
            log_P = (log_u.unsqueeze(1) - C +
                     log_v.unsqueeze(0)) / self.epsilon

            P = torch.exp(log_P)
            cost = torch.sum(P * C) * Cmax

            if (abs(cost.detach().tolist() - previous_cost) < 1e-16 * Cmax):
                break
            previous_cost = cost.detach().tolist()

        return cost, P

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Y[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X, Y, P)

実際にSinkhornを用いた場合でもソートした場合と似たような解を得ることができた。

微分可能なソート

1次元の最適輸送問題をSinkhornで解くことで微分可能なソートを実装する。 やることは簡単で、入力の一方(ここでは$X$)をソートしたいデータ、もう一方($Y$)をソート済みの1次元点群としてSinkhornをやるだけでいい。 $Y$はソート済みの点群であれば何でもいいが、ここでは特に[0,1]の範囲に等間隔に存在する点群とした。

Y = torch.linspace(0, 1, n)
plt.scatter(X, torch.zeros(n))
plt.scatter(Y, torch.ones(n))

このとき、Sinkhornを使って点の対応関係を求めると以下のようになる。

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Y[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X, Y, P)

ここで、$Y$はソート済みであるので、Yの1つ目の点($y_1$)と対応づいた$X$側の点は、$X$の中で最も小さい値を持っていることがわかる (1次元の最適輸送は両方の点をソートした場合と一致するので)。同様に$y_i$と対応づいた点は$i$番目に小さい値を持っている。これは以下のように$P$の列方向のargmaxを見ると確認できる。

P.argmax(dim=1), torch.sort(X)[1]
(tensor([2, 4, 0, 3, 1]), tensor([2, 4, 0, 3, 1]))

$X$のソート済みの値はこのようにargmaxを計算しても解けるけど、以下のように輸送行列ともとのデータとの掛け算でも計算できる。この方法でやればSinkhorn + 掛け算なので微分可能になっていることに注意。

P.t() @ X * n
tensor([-1.0106, -0.6100,  0.0467,  0.2172,  0.4024])

普通にソートすると以下。

torch.sort(X)[0]
tensor([-1.0115, -0.6123,  0.0461,  0.2167,  0.4024])

Sinkhornで求めたソート前後の値は厳密には一致していない。これはエントロピー正則化の影響で$P$が厳密には{0, 1/n}の値に張り付いていないことが原因であると思われる。実際に値を見てみても1と0だけになっているわけではないことがわかる。

P
tensor([[7.4395e-42, 6.9331e-04, 1.9931e-01, 4.0038e-08, 1.3129e-24],
        [0.0000e+00, 4.6160e-38, 1.7486e-16, 4.6286e-04, 2.0000e-01],
        [1.9954e-01, 3.2266e-20, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.6854e-15, 6.9372e-04, 1.9954e-01, 9.3685e-09],
        [4.6154e-04, 1.9931e-01, 2.6548e-34, 0.0000e+00, 0.0000e+00]])

せっかくなのでnn.Moduleとして書くと以下のような感じ

class SoftSort(nn.Module):
    def __init__(self, epsilon, steps):
        super(SoftSort, self).__init__()
        self.sinkhorn = Sinkhorn(epsilon=epsilon, 
                                steps=steps)
    def forward(self, x):
        anchors = torch.linspace(0,1,len(x))
        C = (x[:, None] - anchors[None])**2
        a = torch.ones(len(x))/len(x)
        b = torch.ones(len(anchors))/len(anchors)
        _, P = self.sinkhorn(C, a, b)
        return x @ P / a

勾配を使った最適化

Sinkhornを使ってソートが解けることを示した。このアルゴリズムの利点は、ソート後の値がSinkhorn+行列の掛け算のみで得られるため勾配が容易に計算できることである。せっかくなので、以下のような適当な問題を作って入力の$X$を最適化することを考える。

$X$のソート後の値がある点群$Z$と等しくなるように$X$を最適化。

$$\min_{X} \|Z - \text{sort}(X)\|^2$$

$Z$として、[-1, 0]に等間隔に配置された点を考えることにする。 (sortの部分だけbackprop切っても似たような解が得られると思うので例として微妙だけど簡単なのでこれをやる)

from torch.optim import SGD
soft_sort = SoftSort(epsilon=1e-2, steps=10000)
X.requires_grad = True
Z = torch.linspace(-1, 1, n)

optimizer = SGD([X], lr=1e-2)
hist_loss = []
for i in range(300):
    soft_sorted_x = soft_sort(X)
    loss = ((Z - soft_sorted_x)**2).sum()
    loss.backward()
    hist_loss.append(loss.detach().tolist())
    optimizer.step()
    optimizer.zero_grad()
torch.sort(X)[0], Z, soft_sorted_x
(tensor([-1.0200e+00, -5.0110e-01, -3.2722e-04,  5.0038e-01,  1.0193e+00],
        grad_fn=<SortBackward>),
 tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000]),
 tensor([-9.9967e-01, -5.0004e-01, -3.2446e-04,  4.9931e-01,  9.9897e-01],
        grad_fn=<DivBackward0>))
plt.plot(hist_loss)

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Z[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X.detach(), Z, P.detach())

ここで今回考えた最適化問題はソート後の$X$が等間隔に並ぶようにしただけであり、$X$の値自体はソートされていない状態のままであることに注意。

X
tensor([-3.2722e-04,  1.0193e+00, -1.0200e+00,  5.0038e-01, -5.0110e-01],
       requires_grad=True)

まとめ

Sinkhornを使ってソート済みの1次元点列との間の最適輸送問題を解くことで、微分可能な形でソートを計算できるという内容の論文を紹介した。 論文ではもう少し意味ありそうな実験をやっている他、微分可能なquantileの計算もやっている(ソート後の値の適切なインデックスのものを取ってくるだけでいい)。

Sinkhornも1次元の最適輸送問題がソートで解けることも知っていたのに、最適輸送で微分可能なソートができると最初に聞いたときは全然この方法を思いつかなかったので悔しい。

参考

Differentiable Ranking and Sorting using Optimal Transport

Sinkhorn iterationの勾配と輸送行列の関係

はじめに

最近,最適輸送を覚えてちょいちょい使っています. 特にSinkhorn iterationを用いることで,行列計算だけで最適輸送距離をだいたい計算でき,さらに自動微分で勾配が求まるため,色々な問題のロス関数として最適輸送を使うことができます. ところで,最適輸送の式はもともと線形計画問題なので,輸送計画の行列がコストに依存しないとすると,簡単に勾配(らしきもの)を求めることができます. 以下では,自動微分を用いたときと,この方法で求めた勾配を用いたときで,実際にどの程度の差がでるものなのか気になったので,適当なデータに対して両方の手法を適用してみて結果を比較します. 結果,どちらもそんなに変わらないように思えるので,MLのような雑な最適化で済むようなタスクでは自動微分じゃないほうがいいのではないかと思っています. このあたりのことについて,何か知見があれば教えていただけると嬉しいです.

最適輸送問題

適当な2つの2次元空間上の点群$X = \{x_1, \ldots, x_n\}$と$Y = \{y_1, \ldots, y_m\}$の間の最適輸送を考えます.

import numpy as np
import pylab as plt
from scipy.stats import norm

import torch
import torch.nn
from torch.optim import Adam

import ot

torch.random.manual_seed(1234)
X = torch.randn(10, 2)
Y = torch.randn(10, 2) + np.ones(2)*5

n = X.shape[0]
m = Y.shape[0]

a = torch.ones(n)/n
b = torch.ones(m)/m
plt.scatter(X.detach()[:,0], X.detach()[:,1])
plt.scatter(Y.detach()[:,0], Y.detach()[:,1], marker="x")
plt.show()

f:id:ksknw:20211118213219p:plain

これらの点群の間の最適輸送問題は以下のように定義されます.ただし,ground metric (点同士の距離)として$L_2$距離の二乗を使います.つまり,$C\in \mathbb R^{n\times m}$の成分$C_{ij} = \|x_i - y_j\|^2_2$です.

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt$$ ここで,$P$は輸送行列で$P_{ij}$は点$x_i$と$y_j$の間の対応関係の強さを表します.$U$は質量保存則を満たす行列の集合であり,$U=\{P\in [0,1]^{n\times m} | P \mathbf 1_m = \mathbf a, P^\top \mathbf 1_n = \mathbf b\}$と書くことができます. $\mathbf a, \mathbf b$はそれぞれ$X, Y$に含まれる各点の重みを表すベクトルです.ここでは単純に各点群に含まれるすべての点が同じ重みを持つとします.つまり$\mathbf a = \mathbf 1_n/n$,$\mathbf b = \mathbf 1_m/{m}$です. 以下ではこの最小化問題の解を輸送コストと呼びます.(輸送コストのrootは2-Wasserstein距離とも呼ばれます.)

この問題は線形計画問題なので,適当なソルバーに突っ込むと,そこそこのサイズの点群に対しても解を求めることができます. しかし,ソルバー(例えば単体法)は微分できないので,輸送コストをロス関数として用いたいなどの機械学習における用途としては使いにくいです. また,点群のサイズを$N$として,計算量も $O({N}^{3})$ かかってしまいます.

そこで,Sinkhorn iterationを使ってエントロピー正則化つきの最適輸送問題を解くということがよくやられています. エントロピー正則化つきの最適輸送問題は以下のように書けます.

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt - H(P)$$

ここで,$H(P)$は輸送行列$P$のエントロピーを表し,$H(P) = - \sum_{i,j} P_{ij} \log P_{ij}$です. $\epsilon$はエントロピー正則化の強さを表します.

この最適化問題に対して,Sinkhorn iterationと呼ばれる反復法を用いて最適解を見つけることができることが知られています. Sinkhorn iterationにはいくつかの表記の仕方がありますが,ここでは後ろの実装と合わせるためにsoftminを用いた表記を用います.

$$ \mathbf f^{(l+1)} = \text{Min}_\epsilon ^{\text{row}} (S(\mathbf f^{(l)}, \mathbf g^{(l)})) + \mathbf f^{(l)} + \epsilon \log(\mathbf a) \\ \mathbf g^{(l+1)} = \text{Min}_\epsilon ^{\text{col}} (S(\mathbf f^{(l+1)}, \mathbf g^{(l)})) + \mathbf g^{(l)} + \epsilon \log(\mathbf b) $$ ここで,$S(\mathbf f, \mathbf g)_{ij} = C_{ij} - f_i - g_j$です. また,$\text{Min}_\epsilon ^{\text{row}}$は行方向へのsoftminを表します.softminは以下の様な関数です.

$$\min_\epsilon \mathbf z = -\epsilon \log \sum_{i} e^{-z_i/\epsilon}$$

$\mathbf f, \mathbf g$を適当に初期化して,この反復を収束するまで行います. その後以下のように$P$を求めることができます.

$$P_{ij} = e^{f_i/\epsilon} e^{-C_{ij}/\epsilon} e^{g_j/\epsilon}$$

では実際に実装して$P$を求めます.

epsilon = 0.01

def softmin(z, epsilon):
    return -torch.logsumexp(-z/epsilon, dim=1)*epsilon

def sinkhorn(C):
    n,m = C.shape[:2]
    f = torch.ones(n).double()
    g = torch.zeros(m).double()
    C = torch.sum((X[:, None] - Y[None])**2, dim=2)
    C_max = C.max()
    C = C/C.max()

    for i in range(200):
        S = C - f[:, None] - g[None]
        f = softmin(S, epsilon=epsilon) + f + epsilon * torch.log(a)
        S = C - f[:, None] - g[None]
        g = softmin(S.t(), epsilon=epsilon) + g + epsilon * torch.log(b)
    P = torch.diag(torch.exp(f/epsilon)) @ torch.exp(-C/epsilon) @ torch.diag(torch.exp(g/epsilon))
    return (P * C).sum() * C_max, P
C = torch.sum((X[:, None] - Y[None])**2, dim=2)
cost, P = sinkhorn(C)
plt.imshow(P.detach())
plt.show()

f:id:ksknw:20211118214629p:plain

plt.scatter(X.detach()[:,0], X.detach()[:,1])
plt.scatter(Y.detach()[:,0], Y.detach()[:,1], marker="x")
for i,p_i in enumerate(P):
    for j,p_ij in enumerate(p_i):
        plt.plot([X.detach()[i, 0], Y.detach()[j,0]], 
                 [X.detach()[i, 1], Y.detach()[j,1]], 
                 alpha=float(p_ij.detach().numpy())*2, c="gray")
plt.show()

f:id:ksknw:20211118214645p:plain

Sinkhorn iterationの微分

先述したように,Sinkhorn iterationは単に行列の掛け算の形で書くことができます (上ではsoftminを使った表記でしたが,softminも$\epsilon>0$で微分可能な関数です). そのため,Sinkhornは微分可能な演算になっています. 実際にpytorchなどを用いると,容易にSinkhornで求まる輸送コストをロスに使ってパラメータの最適化を行うことができます. ここでは簡単に$X$と$Y$の間の輸送コストが小さくなるように$X$を動かしていく問題を,Sinkhorn iterationの勾配を使って解くことを考えます.

original_X = X.clone().detach()
%%time

X = original_X.clone().detach()
X.requires_grad = True

hist_X = []
hist_X.append(X.detach().numpy().copy())
hist_cost = []


optimizer = Adam([X], lr=0.1)
for i in range(100):    
    C = torch.sum((X[:, None] - Y[None])**2, dim=2)
    cost, P = sinkhorn(C)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
    hist_X.append(X.detach().numpy().copy())
    hist_cost.append(cost.detach().numpy())
hist_X = np.array(hist_X)
CPU times: user 8.83 s, sys: 8.34 ms, total: 8.84 s
Wall time: 8.86 s

ロスの履歴と移動の軌跡をプロットします.

plt.plot(hist_cost)
plt.ylabel("cost")
plt.xlabel("iterations")
plt.show()

f:id:ksknw:20211118214812p:plain

plt.scatter(X.detach()[:,0], X.detach()[:,1])
plt.scatter(Y.detach()[:,0], Y.detach()[:,1], marker="x")

plt.scatter(original_X.detach()[:,0], original_X.detach()[:,1], c="C0")

plt.plot(hist_X[:,:,0], hist_X[:,:,1], c="C0", lw=0.4, alpha=0.6)
plt.show()

f:id:ksknw:20211118214826p:plain

左にあった$X$(青点)が$Y$(オレンジ×)の方に移動しているのがわかります. 実際にSinkhorn iterationの微分を使って輸送コストが小さくなるように最適化を実行することができました.

$P$を用いたSinkhornの微分

ようやく本題です. エントロピー正則化つきの場合の輸送コスト$L_\epsilon$は以下のように書くことができます. $$ {L}_{\epsilon} = \big {\lt} {P}^{*}, C \big \gt - \epsilon H(P^{\ast}) $$ ここで,$P^*(C) = \text{argmin}_{P\in U(a,b)} \left \lt P, C \right \gt$ です.

このとき,$L$の$C$に関する微分は以下です. $$ \frac{\partial {L}_{\epsilon}}{\partial C} = P^{\ast}(C) + \frac{\partial P^{\ast} (C)}{\partial C} - \epsilon \frac{\partial H(P^{\ast}(C))}{\partial C} $$ 右辺第2項, 第3項はコストが少し変わったときの最適な輸送行列の変化に依存していますが,例えば勾配を使って点を少しずつ動かすようなタスクでは輸送行列は大体の場合同じで,たまに切り替わるという挙動をしそうな気がします.この場合は,右辺第2,3項は無視しても最適化の結果には大きな影響がないのではないかという気がします.

Sinkhorn iterationsでは自動微分を使って勾配を計算しました.自動微分の計算のためには$f$や$g$の中間値を保存し,かつ,backward計算する必要があります.これに対して$P^*$はSinkhorn iterationsの最後の値を保存しておくだけで計算可能です(そもそもSinkhorn じゃなくても計算できます). 実際にやってみます.

import torch.autograd
class Sinkhorn_P(torch.autograd.Function):
    @staticmethod
    def forward(ctx, C):
        n,m = C.shape[:2]
        f = torch.ones(n).double()
        g = torch.zeros(m).double()
        C_max = C.max()
        C = C/C.max()

        for i in range(200):
            S = C - f[:, None] - g[None]
            f = softmin(S, epsilon=epsilon) + f + epsilon * torch.log(a)
            S = C - f[:, None] - g[None]
            g = softmin(S.t(), epsilon=epsilon) + g + epsilon * torch.log(b)
        P = torch.diag(torch.exp(f/epsilon)) @ torch.exp(-C/epsilon) @ torch.diag(torch.exp(g/epsilon))
        ctx.save_for_backward(P)
        return (P * C).sum() * C_max, P

    @staticmethod
    def backward(ctx, grad_output1, grad_output2):
        P, = ctx.saved_tensors
        
        return P*grad_output1
    
def sinkhorn_P(C):
    return Sinkhorn_P.apply(C)
%%time

X_P = original_X.clone().detach()
X_P.requires_grad = True

hist_X_P = []
hist_X_P.append(X_P.detach().numpy().copy())

optimizer = Adam([X_P], lr=0.1)
for i in range(100):    
    C = torch.sum((X_P[:, None] - Y[None])**2, dim=2)
    cost, P = sinkhorn_P(C)
    cost.backward()
    optimizer.step()
    optimizer.zero_grad()
    hist_X_P.append(X_P.detach().numpy().copy())
hist_X_P = np.array(hist_X_P)
CPU times: user 3.48 s, sys: 3.4 ms, total: 3.48 s
Wall time: 3.48 s
plt.scatter(X_P.detach()[:,0], X_P.detach()[:,1], c="C2")
plt.scatter(Y.detach()[:,0], Y.detach()[:,1], marker="x", c="C1")
plt.scatter(original_X.detach()[:,0], original_X.detach()[:,1], c="C0")

plt.plot(hist_X_P[:,:,0], hist_X_P[:,:,1], c="C2", lw=0.4, alpha=0.6)
plt.show()

f:id:ksknw:20211118215619p:plain

似たような感じで$X$を更新することができているように見えます.重ねて表示してみます.

plt.scatter(X_P.detach()[:,0], X_P.detach()[:,1], c="C2")
plt.scatter(Y.detach()[:,0], Y.detach()[:,1], marker="x", c="C1")
plt.scatter(original_X.detach()[:,0], original_X.detach()[:,1], c="C0")

plt.plot(hist_X_P[:,:,0], hist_X_P[:,:,1], c="C2", lw=0.4, alpha=0.6)
plt.scatter(X_P.detach()[:,0], X_P.detach()[:,1], c="C2")

plt.plot(hist_X[:,:,0], hist_X[:,:,1], c="C0", lw=0.4, alpha=0.6)
plt.show()

f:id:ksknw:20211118215636p:plain

多少軌跡がずれていますが,最終的な収束点はほぼ同じように見えます.また,実行時間はSinkhornの自動微分を用いると10s程度かかっているのに対して,3.8sなので倍以上速くなっています(今回はCPUで実行したのでGPUだとまた違うかも).

おわりに

Sinkrhon iterationをbackwardして求めた勾配と最適な輸送行列を使った勾配っぽいものを比較しました. あんまり真面目に検証してないですが,少なくとも今回のタスクでは$P^*$を勾配として使っても問題なさそうでした. 点群の一致でしか検証してないので,勾配が大きくずれる条件やSinkhorn iterationを真面目にbackpropしない最適化できない問題(凸性が重要なときとか? Pを使うと振動するとか?)があるかもしれません. このあたりのことについて,何か知見があれば教えていただけると嬉しいです.

参考

[1803.00567] Computational Optimal Transport

Soft-DTWで台風軌跡のbarycenterを求める

はじめに

2つの時系列データを比較し,それらの間の遠さを知る方法として,Dynamic Time Warping (DTW)があります. DTWは,2つの時系列データの各フレームを対応付けることによって定義されます. このとき,対応付けは「対応付けられたフレーム間の距離が最小になる」ように定められます. この操作(実際には動的計画法)は微分可能ではない1ため,DTWを例えば深層学習モデルの目的関数として用いると最適化が安定しないなどの問題がありました.

soft-DTW (Marco Cuturi and Mathieu Blondel, ICML217) はDTWを微分可能な形に拡張するものです. 実装としては非常に簡単で,DTWの動的計画法の中で出てくるmin関数をsoft-min関数に置き換えるだけで実現可能です. ここでは,soft-DTWをPyTorchによって実装し,これを目的関数とした最適化問題の例として,台風の軌跡のbarycenterを求める問題を解くことで,DTWとsoft-DTWを比較し,soft-DTWによってより自然な(?)barycenterが得られることを示します.

ここでは,PyTorchを用いて実装を行いますが,適当な実装なので,遅いです. 例えば,tslearnの実装を用いることでより高速に(かつバグの可能性が少ない)結果を得ることができます.

Dynamic Time Warping

DTWについては過去に実装したものがあるので,詳しくはそちらを参照ください. DTWでは2つの時系列データ$X = x_1, \dots ,x_n$, $Y = y_1, \dots y_{m}$が与えられたとき,これらの間の距離を以下のように定義します.

$$ DTW(X, Y) \equiv \min_{A\in \mathcal A} \big<A, \Delta(X, Y) \big > $$

ここで,$\big< \big>$は行列の内積(要素ごとに積をとって和)を表し,$\Delta(X,Y) \in \mathbb R^{n \times m}$は各要素$\Delta_{ij}$がフレーム間の距離$\Delta_{ij} = d(x_i, y_j)$を表すコスト行列です. $A$はアライメント行列と呼ばれる,どのフレーム間を対応付けるのかを表す行列であり,$A_{ij} \in {0,1}$です. $A_{ij}=1$は$x_i$と$y_j$を対応付けることを意味します. また,$\mathcal A$はアライメント行列の制約条件を満たす行列の集合を表します.

例として,以下の2つの時系列データ入力として,適当なsin波に対してDTWを適用します.

import numpy as np
import pylab as plt
import pandas as pd
import matplotlib.gridspec as gridspec
import torch
T = 50
t = .4

X = np.sin(np.array(range(T))/10)
Y = np.sin((np.array(range(T))/10 + t*np.pi))

plt.plot(X)
plt.plot(Y)
plt.show()

f:id:ksknw:20200926131547p:plain

DTWは以下のような動的計画法によって得ることができます. ここでmは累積誤差を表しています. DTWにおける動的計画法については,こちらのブログのアニメーションがわかりやすいです.

X,Y = torch.FloatTensor(X), torch.FloatTensor(Y)

Delta = (X[:,None] - Y[None, :])**2

def dtw(Delta):
    S, T = Delta.shape

    m = torch.zeros(S, T)
    m[0,0] = Delta[0,0]
    for i in range(1,S):
        m[i,0] = m[i-1,0] + Delta[i,0]
    for j in range(1,T):
        m[0,j] = m[0, j-1] + Delta[0,j]

    for i in range(1,S):
        for j in range(1,T):
            m[i,j] = Delta[i,j] + min(m[i-1,j], m[i,j-1], m[i-1,j-1])
    return m[-1, -1]
dtw(Delta)
tensor(6.1637)

次に,アライメント行列を求めます. アライメント行列は,動的計画法のminを求めている部分でargminを保存しておくことで,容易に求めることができますが,ここではDTWのコスト行列に対する勾配を利用することで,アライメント行列を求めます. もう一度DTWの式を眺めます.

$$ \begin{align} \text{dtw}(X, Y) & \equiv \min_{A\in \mathcal A} \big < A, \Delta(X, Y) \big > \\ &= \big < A^*, \Delta(X, Y) \big > \end{align} $$

DTWの値は最適なアライメント行列 $A^ *$ とコスト行列 $\Delta(X,Y)$ の内積です. このため, $A^*$ のij成分は以下のように求めることができます.

$$ A^*_{ij} = \frac{\partial \text{dtw}(X, Y)}{\partial \Delta(X,Y)_{ij}} $$

動的計画法の勾配を明に実装するのはちょっと面倒ですが,autogradを使うと以下のように容易に求めることができます.

Delta.requires_grad = True
dtw(Delta).backward()
A = Delta.grad.detach().numpy()

それっぽい絵を描くと以下です.

def plot_path(A, X, Y, D):
    plt.figure(figsize=(5,5))
    gs = gridspec.GridSpec(2, 2,
                       width_ratios=[1,5],
                       height_ratios=[5,1]
                       )
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    ax4 = plt.subplot(gs[3])

    ax2.get_xaxis().set_ticks([])
    ax2.get_yaxis().set_ticks([])    
    ax2.pcolor(A)
    
    ax4.plot(X)
    ax4.set_xlabel("$X$")
    ax1.invert_xaxis()
    ax1.plot(Y, range(len(Y)), c="C1")
    ax1.set_ylabel("$Y$")

    ax2.set_xlim(0, len(X))
    ax2.set_ylim(0, len(Y))
    plt.show()
plot_path(A, X, Y, Delta.detach().numpy())

f:id:ksknw:20200926131604p:plain

対応付けられたフレーム同士を線でつなぐとこのように1対1に対応付けられていることがわかります.

plt.plot(X)
plt.plot(Y)

for i,j in np.array(np.meshgrid(np.arange(len(X)), np.arange(len(Y)))).reshape(2, -1).T:
    plt.plot([i,j], [X[i], Y[j]], alpha=A[i,j], c="gray")

f:id:ksknw:20200926131617p:plain

DTWの勾配を利用することでアライメント行列を求めることができました.

次にDTWを目的関数とした以下の最適化問題を勾配法によって解くことを考えます.

$$ \min_\theta \text{dtw}(X,Y_\theta) $$

このような問題は,例えば,RNNから出力される時系列データ$Y_\theta$に関して,$\theta$,つまり,RNNのパラメータを最適化したい場合などがあります.

DTWのYに関する勾配は

$$ \begin{aligned} \frac{\partial \text{dtw}(X,Y)}{\partial Y_j} & = \sum_{i} \frac{\partial \text{dtw}(X, Y)}{\partial \Delta(X,Y)_{ij}} \frac{\partial \Delta(X,Y)_{ij}}{\partial Y_j} \ &= \sum_{i} A_{ij}^* \frac{\partial \Delta(X,Y)_{ij}}{\partial Y_j} \end{aligned} $$

として求めることができます. このように,DTWを目的関数としても勾配を求めることはできるため,勾配を用いてYを徐々に更新することは可能です. しかし,勾配の形を見てもわかるとおり,この勾配は,$Y$が変化してもアライメント行列$A_{ij}^ *$は変化しないことを仮定しています. しかし,実際には時系列データ$Y$を更新していくと,あるとき,アライメント行列$A_{ij}^ *$は切り替わるような動きをします. このような挙動は最適化の性能を落としてしまいます.

Softminとmin

soft-DTWの説明の前に,softmin関数を導入します. これは以下のように定義することができます.

$$ \text{min}^\gamma(x_1, \dots, x_n) \equiv \begin{cases} -\gamma \log \sum_{i=1}^n \exp (-x_i / \gamma) & (\gamma>0)\\ \min (x_1, \dots, x_n) &(\gamma=0) \end{cases} $$

softmin関数はminを滑らかにしたものであり,あとでみるように$\gamma$を0に近づけていくと,minに近づきます.

余談: ニューラルネットワークをやっているひとはsoftmaxを知っていると思います. 名前がややこしくて混乱しますが,NNのsoftmax関数は0~1の出力で,自分が(単独で)最大であれば1に近づく関数という意味で,soft-argmax関数というほうが正しいでしょう.これに対して,ここで定義したsoftminはminを滑らかにしたものという意味で正しくsoftmin関数です.

softmin関数の挙動について,詳しく見てみます. 入力$x_1, x_2$が与えられたとき,minの等高線,および,softminの出力の等高線は以下のグラフのようにかけます.

def softmin(a, gamma, dim=0):
    return -gamma * torch.logsumexp(-a / gamma, dim=0)

minの等高線

grid = torch.stack(torch.meshgrid( torch.linspace(0,1, 100), torch.linspace(0,1,100))).reshape(2, -1)

fig, ax = plt.subplots(1,3,figsize=(15,5))
for ax_i in ax:
    ax_i.axis("equal")
    ax_i.set_xlabel("$x_1$")
    ax_i.set_ylabel("$x_2$")
    
min_values = torch.min(grid, dim=0)[0]
cs = ax[0].contour(min_values.numpy().reshape(100,100))
ax[0].clabel(cs, inline=1, fontsize=10)
ax[0].set_title("$\gamma=0$")

softmin_values = softmin(grid, 0.01, dim=0)
cs = ax[1].contour(softmin_values.numpy().reshape(100,100))
ax[1].clabel(cs, inline=1, fontsize=10)
ax[1].set_title("$\gamma=0.01$")

softmin_values = softmin(grid, 0.1, dim=0)
cs = ax[2].contour(softmin_values.numpy().reshape(100,100))
ax[2].clabel(cs, inline=1, fontsize=10)
ax[2].set_title("$\gamma=0.1$")

plt.show()

f:id:ksknw:20200926131630p:plain

図より,minは直角な等高線をしているのに対して,softminでは$\gamma$を大きくするにつれて徐々に角が丸まっていくのがわかります. これらの関数の勾配は等高線に対して直角方向に対応します. min関数では勾配はどちらの値が小さいかに対応して,(1,0)もしくは(0,1)になります. これに対して,softminでは入力の値が近いときには,斜め方向に勾配をもつことがわかります.

Soft-DTW

DTWでは勾配の計算に,そのときの最適なアライメント行列$A^ *$のみを考慮し,他のアライメント行列の可能性を一切考慮していませんでした. このため,時系列データ$Y$が少し変化した際に,$A^ *$が急激に変化することがあり,これが最適化の邪魔をしてしまうことがありました.

Soft-DTWは,最適なアライメント行列だけを考えるのではなく,すべてのアライメント行列を重み付きで考慮します. Soft-DTWは以下のように定義されます.

$$ \text{dtw}^\gamma(X, Y) \equiv \text{min}_{A\in \mathcal A}^\gamma \big<A, \Delta(X, Y) \big > $$

Soft-DTWはDTWと同様に動的計画法によって解くことができます(minをsoftminに置き換えるだけです).

def softdtw(Delta, gamma):
    S, T = Delta.shape

    m = torch.zeros(S, T)
    m[0,0] = Delta[0,0]
    for i in range(1,S):
        m[i,0] = m[i-1,0] + Delta[i,0]
    for j in range(1,T):
        m[0,j] = m[0, j-1] + Delta[0,j]

    for i in range(1,S):
        for j in range(1,T):
            m[i,j] = Delta[i,j] + softmin(torch.stack([m[i-1,j], m[i,j-1], m[i-1,j-1]]), gamma)
    return m[-1, -1]
softdtw(Delta, 0.01)
tensor(5.7458, grad_fn=<SelectBackward>)

Soft-DTWのコスト行列に関する勾配は以下のようになります.

$$ \frac{\partial \text{dtw}^\gamma(X, Y)}{\partial \Delta(X,Y)_{ij}} = \frac{1}{Z}\sum_{A\in \mathcal A} \exp \big(- \big< A, \Delta \big> /\gamma \big) A_{ij} $$

ここで,$Z = \sum_{A\in \mathcal A} \exp \big(- \big< A, \Delta \big> /\gamma \big)$です. この右辺をよく見ると,Aの確率分布$p(A)$をGibbs分布$p(A) = \frac{1}{Z}\sum_{A\in \mathcal A} \exp \big(- \big< A, \Delta \big> /\gamma \big)$としたときの,アライメント行列の期待値$\mathbb E_\gamma [A]$になっていることがわかります.

論文の中では明なbackwardルールが記載されていますが,これもautogradを使えば,特に何も気にせず勾配を得ることができます(速度には差があるかもしれません).

Delta.grad.zero_()
softdtw(Delta, 0.01).backward()
A = Delta.grad.detach().numpy()
plot_path(A, X, Y, Delta.detach().numpy())

f:id:ksknw:20200926131650p:plain

DTWのときとは異なり,アライメント行列がぼやっとしていることがわかります.これは最適な1つのアライメント行列だけでなく,コストがあまりかわらない別のアライメント行列も考慮されていることに対応します.

plt.plot(X)
plt.plot(Y)

for i,j in np.array(np.meshgrid(np.arange(len(X)), np.arange(len(Y)))).reshape(2, -1).T:
    plt.plot([i,j], [X[i], Y[j]], alpha=A[i,j], c="gray")

f:id:ksknw:20200926131706p:plain

最適化問題のロス関数として用いたときの比較

では実際に,最適化のロス関数としてDTWおよびsoft-DTWを用いたときの挙動を確認します. ここでは2017年から2020年までに日本に上陸した台風の軌跡データから,平均的な軌跡(barycenter)を求める問題を考えます. データは気象庁から公開されているものを用いました.

DTW barycenter $Y$を求める最適化は以下のように書くことができます.

$$ \min_Y \sum_{n=1}^N \text{dtw}(X_n,Y) $$

これは$N$個の時系列データ$X_1, \dots, X_N$が与えられたとき,これらとのDTW距離が最小になる$Y$を求める問題です. soft-DTWを用いる場合も同様に定義することができます. 一般にDTW barycenterはDBAなどの方法が用いられることが多いですが,ここでは,勾配法によってbarycenter $Y$を求めることにします.

とりあえずデータをプロットします

from glob import glob
from mpl_toolkits.basemap import Basemap
traces = []
for file in sorted(glob("./data/table*")):
    data = pd.read_csv(file, encoding="shift_jis")
    for ind in data["台風番号"].unique():
        temp = data[data["台風番号"] == ind]
        if sum(temp["上陸"]==1) > 0:
            traces.append(torch.FloatTensor(temp[["経度", "緯度"]].values))
def plot(traces, Y=None):
    m = Basemap(llcrnrlon=100.,llcrnrlat=10.,urcrnrlon=170.,urcrnrlat=60.,
                rsphere=(6378137.00,6356752.3142),
                resolution='h',projection='merc')
    m.fillcontinents()
    
    for tr in traces:
        tr_numpy = tr.detach().numpy()
        x,y = m(tr_numpy[:, 0], tr_numpy[:, 1])
        plt.plot(x,y,  marker="", lw=1, ls="-", c="gray", alpha=0.4)
        
    if Y is not None:
        Y_numpy = Y.detach().numpy()
        x,y = m(Y_numpy[:, 0], Y_numpy[:, 1])
        plt.plot(x,y,  marker="", lw=1, ls="-", c="C1")
plot(traces)

f:id:ksknw:20200926131724p:plain

barycenter $Y$の初期値として,ここでは,データとして与えられた軌跡の中から最もbarycenterに近いもの,つまり,他の時系列データとのDTWの合計が小さいものを選びます.

from itertools import combinations
dists = np.zeros((len(traces), len(traces)))

for i in range(len(traces)-1):
    for j in range(i+1, len(traces)):
        Delta = ((traces[i][:,None] - traces[j][None, :])**2).sum(axis=-1)
        dists[i,j] = dtw(Delta)
np.argmin(np.sum(dists + dists.T, axis=0))
8
Y = traces[8].clone()
Y = torch.FloatTensor(Y)
Y.requires_grad = True

では,実際に勾配法によって$Y$を更新してみます.

DTW ($\gamma=0$)

from torch.optim import Adam

optimizer = Adam([Y], lr=1)
from tqdm.notebook import tqdm
hist = []

epoch = 50

for e in tqdm(range(epoch)):
    loss = 0
    for X in traces:
        Delta = ((X[:,None] - Y[None, :])**2).sum(dim=-1)
        loss += dtw(Delta)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    hist.append(loss.detach().numpy())
    
plt.plot(hist)

f:id:ksknw:20200926131739p:plain

plot(traces, Y)

f:id:ksknw:20200926131756p:plain

ロスは収束してしますが,求まった軌跡はガタガタで平均的な台風の軌跡とは言えなさそうです. (そもそもDTW barycenterのみた目はあまり平均的な軌跡っぽくならないというのもありますが)

soft-DTW ($\gamma=1$)

次にsoft-DTWを用いてbarycenterを求めてみます.ロスがsoft-DTWになっていること以外は上と同じです.

Y_1 = traces[8].clone()
Y_1.requires_grad = True

optimizer = Adam([Y_1], lr=1)

hist = []

for e in tqdm(range(epoch)):
    loss = 0
    for X in traces:
        Delta = ((X[:,None] - Y_1[None, :])**2).sum(dim=-1)
        loss += softdtw(Delta, gamma=1)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    hist.append(loss.detach().numpy())
plt.plot(hist)

f:id:ksknw:20200926131825p:plain

plot(traces, Y_1)

f:id:ksknw:20200926131845p:plain

期待と異なり,DTWを用いた場合とあまり変わらない結果になりました.

soft-DTW ($\gamma=100$)

$\gamma=1$ではDTWを用いた場合とあまりかわらないように見えたので,次は$\gamma=100$にしてbarycenterを求めてみます.

Y_100 = traces[8].clone()
Y_100.requires_grad = True

optimizer = Adam([Y_100], lr=1)

epoch = 50
hist = []

for e in tqdm(range(epoch)):
    loss = 0
    for X in traces:
        Delta = ((X[:,None] - Y_100[None, :])**2).sum(dim=-1)
        loss += softdtw(Delta, gamma=100)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    hist.append(loss.detach().numpy())
plt.plot(hist)

f:id:ksknw:20200926131903p:plain

plot(traces, Y_100)

f:id:ksknw:20200926131920p:plain

それらしい(?)結果が得られました. 並べて描くと以下です.

plt.figure(figsize=(15,5))

plt.subplot(131)
plot(traces, Y)
plt.title("DTW")

plt.subplot(132)
plot(traces, Y_1)
plt.title("Soft-DTW ($\gamma=1$)")

plt.subplot(133)
plot(traces, Y_100)
plt.title("Soft-DTW ($\gamma=100$)")
Text(0.5, 1.0, 'Soft DTW ($\\gamma=100$)')

f:id:ksknw:20200926131933p:plain

soft-DTWの$\gamma$を大きな値に設定することで,それらしいbarycenterを得ることができました. 一方で,これが,微分可能になったことで最適化がうまくできたことが原因かと言われると微妙です.

DTW barycenterはそもそもギザギザの形状を取りやすい性質があります. 例えば,西を大きく回って東に抜ける軌跡と,東側を抜けていく軌跡があるとき,平均的な軌跡としては,これらの真ん中を通過するような軌跡を期待しますが,DTWの合計を最も小さくしようとすると,例えば多くの点は東側を抜ける軌跡と一致させ,西に数点の点を取るギザギザの軌跡が選択されることがあります.これは,西側にある数点を,西を大きく回る軌跡の多くの点と対応付けることで,DTWの値を小さくすることができるためです. Soft-DTWの$\gamma$を大きくすると,ある1つのアライメント行列だけでなく,他のアライメント行列もロスとして考慮されます.このとき,西の数点以外の点と,西側を回る軌跡が対応付けられるため,結果として,barycenterを期待するようなものへと変化させる勾配が生まれている気がします.

まとめ

DTWを微分可能な形に拡張したsoft-DTWの実装を行いました. DTWとの違いはmin関数の代わりにsoftmin関数を用いるところで,実装はPyTorchの自動微分を用いることで容易に可能でした. 台風の軌跡データを用いてbarycenterを求め,soft-DTWの$\gamma$の値を大きくすることで,それらしいbarycenterが得られることを確認しました. 一方で,微分可能になったことの恩恵を受けている気はあまりしませんでした.barycenter以外の別の最適化問題を解かせてみると,大きな違いがあるかもしれません. また,普通,DTW barycenterを求めるときは普通DBA(今のbarycenterと他の系列をアライメントして対応付けられた点の平均でbarycenterを更新)を用います.このアライメントと平均を交互に求める方法はsoft-DTWであっても利用可能です. この意味でも別の最適化問題で挙動を見てみることは意味があると思います.

参考


  1. 正確には,対応付けをfixだとしたときの勾配のみが求まる.

matplotlibだけでアノテーターを作る

はじめに

たまに,特定のタスクのために変なアノテーターを作りたいときがある. 慣れているので,pythonでやりたい. これまでこういうときは,opencvを使って作っていたが,最近,matplotlibを使っても同じようなことができると知ったので,調べて使ってみる.

公式ドキュメントをeventを受け取る関数を作って,それを fig.canvas.mpl_connect という関数で登録することでマウスやキーボードのイベントを取得できるらしい. 以下ではよくある感じのアノテーターを作って,どんな感じでできるかを確認する.

画像にラベルをつける

  • ←→が押されたら前の画像,次の画像を表示する.
  • キーが押されたら,押されたキーを記録して次の画像へ移動
  • qが押されたら結果をpickleに保存して終了

矩形選択

  • ←→が押されたら前の画像,次の画像を表示する.
  • ドラッグで領域選択して座標を取得
  • 選択中は四角を表示する
  • cを押したら前回の結果をキャンセルする
  • qが押されたら結果をpickleに保存して終了
  • keyのイベント,マウス押し込み,マウスドラッグ,マウスリリースのそれぞれのイベントを検出する関数を作って登録する.
  • 想像よりだいぶ長くなってしまってイマイチ

f:id:ksknw:20200810231441p:plain

特定の物体が写ってる画像を選択

  • クリックされたら,クリックされた画像のインデックスを保存して,別の画像を表示する.
  • 人間性を証明するためにたまにやらされるやつ
  • マウスイベントではevent.inaxesとして,axisのオブジェクトが返ってくる.subplotの番号がほしいときは,self.plot_axes.index(event.inaxes)とかやると番号を取得できる.

まとめ

  • matplotlibを使ってアノテーターを作ってみた.
  • いくつかやってみた感じ,opencvを覚えてるなら,opencvで作るのと対して手間は変わらないかなという印象(subplot使えるぶんだけ有用かも).
  • plt.drawを使うと若干ラグが生じるときがあって,plt.pause()で適当に短い時間を指定したほうが軽快だった.
  • ここに書いたようなものだったら何でもできそうだけど,matplotlibの機能を使ったアノテーター(scatterの点の位置をドラッグして補正するとか)するときは便利かもしれない.公式ドキュメントにはbarプロットのバーをドラッグして動かすデモが載っている.

参考

Event handling and picking — Matplotlib 3.1.2 documentation

部分一致DTW(SPRING)の実装

はじめに

以前,2つの時系列データの距離的なものを測るアルゴリズムであるDTWについて,以下のようなものを書きました.

ksknw.hatenablog.com

DTWはいい感じのアルゴリズムですが,時系列データの全体と全体を対応付けるので,短いパターンを長い時系列データの中から見つけるみたいなタスクに直接用いることはできません(やろうと思えばできるけど計算量やばい). なんかもっといいやつないかなと思っていたところに,たまたま読んだ論文で使われていたSPRINGというアルゴリズムを知ったので,ここでは実装しつつちゃんと理解してみることにしました. 以下ではpythonを使ってSPRINGを実装します.このノートブックの全体は以下にあります.

github.com

目的

ここでは,時系列データ$X$からある時系列パターン$Y$と似ている区間を見つける問題について考えます. 例として,以下の図に示すような時系列データ$X$ (青線)と時系列パターン$Y$(オレンジ線)を考えます. ちなみにこれは気象庁配布されている名古屋の気温データです.

import numpy as np
import pylab as plt
import pandas as pd
import matplotlib.gridspec as gridspec
from tslearn.metrics import dtw_path
data = pd.read_csv("./data.csv", header=None)[1].values
X = data[:1000:4]
Y = data[1000::4]

plt.plot(X, label="X")
plt.plot(Y, label="Y")
plt.legend()
plt.show()

f:id:ksknw:20191228102715p:plain

ここでの目的は以下のような部分時系列(緑部分)を時系列データ$X$から検出することです.

from spring import spring
for path, cost in spring(X, Y, 80):
    plt.plot(X, c="gray", alpha=0.5)
    plt.plot(path[:,0], X[path[:,0]], C="C2")

f:id:ksknw:20191228102754p:plain

2つの時系列データの距離(のようなもの)を測る方法の1つとして,Dynamic Time Warping (DTW)があります. DTWについて詳しくは前書いたやつを参照. しかし,DTWは2つの時系列データ全体同士を比較するものなので,これを今回の目的にそのまま用いることはできません. 時系列データの部分一致の問題を扱うアルゴリズムの1つとして,SPRINGがあります. 以下では,DTWについて簡単に説明したあと,部分一致DTW,そして,SPRINGについて説明します.

DTWと動的計画法

DTWでは2つの時系列データ$X = [X^{(1)}, \dots, X^{(T_x)}]$と時系列データ$Y = [Y^{(1)}, \dots, Y^{(T_y)}]$の距離的なもの(以下ではDTW距離)を求めます. DTWでは2つの時系列データに含まれる各フレーム$X_i$, $Y_j$を,フレーム間の距離の総和が最小になるように対応付けを求めます. このときの,フレーム間の距離の総和がDTW距離です.

ここではフレーム間の距離として二乗距離を考えます. 各フレーム間の距離を並べた行列$\Delta$, $\Delta_{i,j}= \|X^{(i)} - Y^{(j)} \|^2_2$は以下のようなものです. (以下では,説明のために,一旦同じぐらいの長さの時系列データ$X,Y$を考えます.)

X = data[280:400:4]
Y = data[1000::4]
D = (np.array(X).reshape(1, -1) - np.array(Y).reshape(-1, 1))**2
plt.imshow(D, cmap="Blues")
plt.xlabel("$X$")
plt.ylabel("$Y$")
plt.show()

f:id:ksknw:20191228102848p:plain

次にDTWではフレーム間の対応付け(アライメント)を求めます. 実際に求まるアライメントを以下の図に示します.

def dist(x, y):
    return (x - y)**2

def get_min(m0, m1, m2, i, j):
    if m0 < m1:
        if m0 < m2:
            return i - 1, j, m0
        else:
            return i - 1, j - 1, m2
    else:
        if m1 < m2:
            return i, j - 1, m1
        else:
            return i - 1, j - 1, m2

def dtw(x, y):
    Tx = len(x)
    Ty = len(y)

    C = np.zeros((Tx, Ty))
    B = np.zeros((Tx, Ty, 2), int)

    C[0, 0] = dist(x[0], y[0])
    for i in range(Tx):
        C[i, 0] = C[i - 1, 0] + dist(x[i], y[0])
        B[i, 0] = [i-1, 0]

    for j in range(1, Ty):
        C[0, j] = C[0, j - 1] + dist(x[0], y[j])
        B[0, j] = [0, j - 1]

    for i in range(1, Tx):
        for j in range(1, Ty):
            pi, pj, m = get_min(C[i - 1, j],
                                C[i, j - 1],
                                C[i - 1, j - 1],
                                i, j)
            C[i, j] = dist(x[i], y[j]) + m
            B[i, j] = [pi, pj]
    cost = C[-1, -1]
    
    path = [[Tx - 1, Ty - 1]]
    i = Tx - 1
    j = Ty - 1

    while ((B[i, j][0] != 0) or (B[i, j][1] != 0)):
        path.append(B[i, j])
        i, j = B[i, j].astype(int)
    path.append([0,0])
    return np.array(path), cost, C
path, dtw_dist, C = dtw(X, Y)
def plot_path(paths, A, B, D):
    plt.figure(figsize=(5,5))
    gs = gridspec.GridSpec(2, 2,
                       width_ratios=[1,5],
                       height_ratios=[5,1]
                       )
    ax1 = plt.subplot(gs[0])
    ax2 = plt.subplot(gs[1])
    ax4 = plt.subplot(gs[3])

    ax2.pcolor(D, cmap=plt.cm.Blues)
    ax2.get_xaxis().set_ticks([])
    ax2.get_yaxis().set_ticks([])
    
    for path in paths:
        ax2.plot(path[:,0]+0.5, path[:,1]+0.5, c="C3")
    
    ax4.plot(A)
    ax4.set_xlabel("$X$")
    ax1.invert_xaxis()
    ax1.plot(B, range(len(B)), c="C1")
    ax1.set_ylabel("$Y$")

    ax2.set_xlim(0, len(A))
    ax2.set_ylim(0, len(B))
    plt.show()
plot_path([np.array(path)], X, Y, D)

f:id:ksknw:20191228102911p:plain

図の中で,青線,オレンジ線はそれぞれ$X$, $Y$を表しています. 右上の図の青色は$\Delta$を表しており,1本の赤線(パス)で表したものがアライメントされたフレームを表します. ここで,アライメントは以下の3つの制約条件を満たすもののなかで,パス上のコストが最小になるものです.

  • 境界条件: 両端が左下と右上にあること
  • 単調性: 左下から始まり,→,↑,➚のいずれかにしか進まないこと
  • 連続性: 繋がっていること

フレームの対応付けを図に書くと以下です. 境界条件は$X,Y$の両端のフレームを対応付けること,単調性は対応づけがクロスしないこと,連続条件は対応付けられないフレームが存在しないことを表しています.

for line in path:
    plt.plot(line, [X[line[0]], Y[line[1]]], linewidth=0.8, c="gray")
plt.plot(X)
plt.plot(Y)
plt.show()

f:id:ksknw:20191228102926p:plain

DTWではこれらの制約条件を満たすパスの集合$\mathcal A$のうち,パス上の距離$\Delta$の総和が最小となるパスを求めるアルゴリズムです. 式で書くと以下です. $$\min_{A \in \mathcal A} \big < A, \Delta \big >$$ ただし,$\big < \cdot, \cdot \big >$は行列の内積を意味します.

この最適化問題動的計画法を使って解くことができます. 動的計画法では,左下のマスからスタートして,各マスに到達するため最小の累積コストを1マスずつ求めます. このとき,→,↑,➚のいずれかにしか移動できないので,$i+1, j+1$マスの最小の累積和$C$は $$C_{i,j} = \min \big(C_{i, j+1}, C_{i+1, j}, C_{i+1, j+1} \big)$$ で求めることができます. ただし,左下のマス($i=1, j=1$)では,$C_{1,1}=\Delta_{1,1}$であり,下一列($i\neq 1, j=1$)のときは, $$C_{i+1, 1} = C_{i, 1} + \Delta_{i+1, 1}$$ です. また,左一列も同様にして $$C_{i, j+1} = C_{1, j} + \Delta_{1, j+1}$$ です.

累積コストを求めるときに使ったマス($\mathrm{argmin} (C_{i, j+1}, C_{i+1, j}, C_{i+1, j+1})$)を記録しておくことで,終端(右上)にたどり着いた時,この記録を逆にたどることで,累積コストを最小にするパスを求めることができます. この部分のアルゴリズムについては,こちらのアニメーションがわかりやすいです. $C$をプロットすると以下です.

plt.pcolor(C, cmap="Blues")
plt.show()

f:id:ksknw:20191228102940p:plain

DTWを用いることで,時系列データ間の類似度を求めることができますが,今回ような長い時系列データから短い時系列パターンと一致する区間を検出するタスクに直接用いることはできません. ちなみに時系列データと時系列パターンを入力として,DTWをやってみると以下のようなパス,アライメントが得られます.

X = data[:1000:4]
Y = data[1000::4]

D = (np.array(X).reshape(1, -1) - np.array(Y).reshape(-1, 1))**2
path, dtw_dist = dtw_path(X, Y)
plot_path([np.array(path)], X, Y, D)

for line in path:
    plt.plot(line, [X[line[0]], Y[line[1]]], linewidth=0.8, c="gray")
plt.plot(X)
plt.plot(Y)
plt.show()

f:id:ksknw:20191228103002p:plain

f:id:ksknw:20191228103014p:plain

むりやり全体と全体を対応付けようとするため,よくわからない感じになっています.

部分一致DTW

次に,時系列データ$X=[X^{(1)}, \dots, X^{(T_x)}]$から,時系列パターン$Y = [Y^{(1)}, \dots, Y^{(T_y)}], (T_x > T_y)$に,最もDTW距離が近い区間を1箇所見つける問題を考えます. ナイーブな方法として,ありうる全ての部分時系列$[X^{(t)},...,X^{(s)}]$とパターン$Y$とのDTW距離を求め,最も小さいDTW距離をもつ区間を検出する方法が考えられます. しかし,この方法は時系列データ$X$の長さが長くなると計算量が大きくなってしまいます.

SPRINGでは,DTWの動的計画法を一部修正することによって,部分一致問題を解くことを提案しています. (以下では文献とは若干異なるアルゴリズムを説明しますが,たぶんやっていることは同じです.これは自分的実装の楽さから変更されています.) 以下ではこのアルゴリズムを部分一致DTWと呼びます.

DTWの境界条件を以下のように変更することで,部分一致問題に対応することができます.

  • 境界条件: 下の一列のどこかのマスと上の一列のどこかのマスを通る

これは$Y$の両端が$X$のいずれかのフレームと対応付けられることを意味します. 一方で,$X$の両端のフレームは必ずしも対応付けられる必要はありません.

アルゴリズムとしては,動的計画法の中で以下の2点を変更します.

  • 下一列($i\neq 1, j=1$)の累積コストを$C_{i+1, 1} = C_{i, 1} + \Delta_{i+1, 1}$から$C_{i+1, 1} = \Delta_{i+1, 1}$に変更
  • 終了地点を右上のマスから,$\mathrm{argmin}_i (C_{i, T_y})$へと変更

これらの変更はそれぞれ,始点と終点の変更に対応します. これらの変更によって得られたパスを以下の図に示します.

X = data[:1000:4]
Y = data[1000::4]
def dist(x, y):
    return (x - y)**2

def get_min(m0, m1, m2, i, j):
    if m0 < m1:
        if m0 < m2:
            return i - 1, j, m0
        else:
            return i - 1, j - 1, m2
    else:
        if m1 < m2:
            return i, j - 1, m1
        else:
            return i - 1, j - 1, m2

def partial_dtw(x, y):
    Tx = len(x)
    Ty = len(y)

    C = np.zeros((Tx, Ty))
    B = np.zeros((Tx, Ty, 2), int)

    C[0, 0] = dist(x[0], y[0])
    for i in range(Tx):
        C[i, 0] = dist(x[i], y[0])
        B[i, 0] = [0, 0]

    for j in range(1, Ty):
        C[0, j] = C[0, j - 1] + dist(x[0], y[j])
        B[0, j] = [0, j - 1]

    for i in range(1, Tx):
        for j in range(1, Ty):
            pi, pj, m = get_min(C[i - 1, j],
                                C[i, j - 1],
                                C[i - 1, j - 1],
                                i, j)
            C[i, j] = dist(x[i], y[j]) + m
            B[i, j] = [pi, pj]
    t_end = np.argmin(C[:,-1])
    cost = C[t_end, -1]
    
    path = [[t_end, Ty - 1]]
    i = t_end
    j = Ty - 1

    while (B[i, j][0] != 0 or B[i, j][1] != 0):
        path.append(B[i, j])
        i, j = B[i, j].astype(int)
        
    return np.array(path), cost
path, cost = partial_dtw(X, Y)
D = (np.array(X).reshape(1, -1) - np.array(Y).reshape(-1, 1))**2
plot_path([np.array(path)], X, Y, D)

f:id:ksknw:20191228103046p:plain

for line in path:
    plt.plot(line, [X[line[0]], Y[line[1]]], linewidth=0.8, c="gray")
plt.plot(X)
plt.plot(Y)
plt.plot(path[:,0], X[path[:,0]], c="C2")
plt.show()

f:id:ksknw:20191228103117p:plain

$X$から$Y$と似た部分時系列が一部のみ抽出されていることがわかります.

複数の区間の検出

上で説明した手法を用いることで,長い時系列データ$X$から時系列パターン$Y$に一致する部分時系列を抜き出すことができました. 一方で,この方法では最もDTW距離が小さい1つの区間しか検出することができませんでした. SPRINGでは,$X$の中からパターン$Y$とDTW距離が$\epsilon$以下の部分時系列を複数検出することを提案しています. SPRINGでは特にオンライン設定を対象としており,ここでも同様に$X$が徐々に観測されていくときを考えます.

このとき,DTW距離が$\epsilon$以下の部分時系列の区間が重複している場合が考えられます. このような場合に,全ての区間を列挙することもできますが,可視化などの応用を考えると,全ての区間を検出することは適切ではなさそうです. SPRINGでは重複する区間を全て列挙するのではなく,重複するもののうち,最もDTW距離が小さいもの1つを検出します.

SPRINGでは,この問題を扱うために,各マスについて区間の開始地点を記録,伝搬することを提案しています. つまり,各マスを通る最小コストのパスの区間の開始地点$S_{ij}$をそれぞれのマスで記録します. これは動的計画法の中で,以下のように埋めることができます. $$S_{i,1}=i$$ $$S_{i,j} = S_{i^*, j^*}, (i^*, j^*) = \mathrm{argmin}(C_{i,j+1}, C_{i+1,j}, C_{i,j})$$

この操作によって,$S_{i,T_y}$には,$(i, T_y)$を通るコスト最小のパスの開始位置が記録された状態になります. この$S$を用いることで,以下の手順でDTW距離がしきい値$\epsilon$以下の区間を検出することができます. さらに,ここがよく出来ているところですが,SPRINGでは区間を検出するとき,今後観測される$X$に検出した区間と重複し,かつ,DTW距離が検出した区間よりも小さい区間が存在しないことを保証することができます.

  • 現在の時刻を$t$とします.つまり,$X^{(1)},...,X^{(t)}$と$Y$(全区間)が観測されているとします.
  • 時刻$t$においてDTW距離が最も小さい区間の終了位置を$i^*$,DTW距離を$d_{min}$とします.この区間を候補と呼びます.

このとき,現在の観測時刻$t$において, 全ての$j = 1,...,T_y$について,$C_{t, j} \geq d_{min}$ または,$S_{t, j} \geq i^*$ が成り立つなら,現在の候補を検出結果として出力します.

これを図にすると以下です(突然の手書き図). 今,$t=5$で,$C_{1,T_y}$から$C_{t,T_y}$をチェックして,候補(赤線)が見つかったとします. このとき,SPRINGでは現在の時刻($t$)を通るパス(青や緑)について考えます. ここで,$t=6$以降はまだ観測していないので,これらのパスもDTW距離も計算することができないことに注意してください.

まず,緑色のパスは赤色の候補と区間が重複していないため,候補を検出するか否かを考える際に,考慮する必要がありません. $S_{t, j} \geq i^*$を満たすパスがこのパスに相当します.

次に青色のパスについて考えます. 青色のパスは候補と区間が重複しているため,もし青色のパスのほうがDTW距離が小さい場合は赤色の候補を検出してはいけません. これを判定する基準として,SPRINGでは$C_{t, j}$を用います. DTWでは正のフレーム間コストを考えるので,そのパスが高さ方向$j$の時点で$C_{t, j}$が$d_{min}$を超えた場合は,現在の候補よりも青色のパスのほうが必ずDTW距離が大きくなることがわかるからです.

f:id:ksknw:20191228103403j:plain

実際にSPRINGによって複数の区間を検出した結果を以下に示します.

def spring(x, y, epsilon):
    Tx = len(x)
    Ty = len(y)

    C = np.zeros((Tx, Ty))
    B = np.zeros((Tx, Ty, 2), int)
    S = np.zeros((Tx, Ty), int)

    C[0, 0] = dist(x[0], y[0])

    for j in range(1, Ty):
        C[0, j] = C[0, j - 1] + dist(x[0], y[j])
        B[0, j] = [0, j - 1]
        S[0, j] = S[0, j - 1]
        
    for i in range(1, Tx):
        C[i, 0] = dist(x[i], y[0])
        B[i, 0] = [0, 0]
        S[i, 0] = i
        
        for j in range(1, Ty):
            pi, pj, m = get_min(C[i - 1, j],
                                C[i, j - 1],
                                C[i - 1, j - 1],
                                i, j)
            C[i, j] = dist(x[i], y[j]) + m
            B[i, j] = [pi, pj]
            S[i, j] = S[pi, pj]
            
        imin = np.argmin(C[:(i+1), -1])
        dmin = C[imin, -1]
        
        if dmin > epsilon:
            continue
            
        for j in range(1, Ty):
            if (C[i,j] < dmin) and (S[i, j] < imin):
                break
        else:
            path = [[imin, Ty - 1]]
            temp_i = imin
            temp_j = Ty - 1
            
            while (B[temp_i, temp_j][0] != 0 or B[temp_i, temp_j][1] != 0):
                path.append(B[temp_i, temp_j])
                temp_i, temp_j = B[temp_i, temp_j].astype(int)
                
            C[S <= imin] = 100000000
            yield np.array(path), dmin
pathes = []
for path, cost in spring(X, Y, 80):
    for line in path:
        plt.plot(line, [X[line[0]], Y[line[1]]], linewidth=0.8, c="gray")
    plt.plot(X)
    plt.plot(Y)
    plt.plot(path[:,0], X[path[:,0]], C="C2")
    plt.show()
    pathes.append(path)

f:id:ksknw:20191228103137p:plain

f:id:ksknw:20191228103151p:plain

path, cost = partial_dtw(X, Y)
D = (np.array(X).reshape(1, -1) - np.array(Y).reshape(-1, 1))**2
plot_path(pathes, X, Y, D)

f:id:ksknw:20191228103212p:plain

まとめ

ここでは,時系列データから特定の時系列パターンに一致した区間を検出する問題について考え,これを解くアルゴリズムの1つであるSPRINGをpythonで実装しました. 部分一致DTWは応用先の広いアルゴリズムで,最近ではDTWNet(ニューラルネットワークconvolutionカーネルの代わりにDTWを使う)に使われていました. 今回はpythonで実装したので遅いです.せっかくのオンライン設定用のアルゴリズムなので,気が向いたらjitとかjuliaとか使って高速化したいです.

参考

Blender + Pythonでポイントクラウドを可視化する

はじめに

ポイントクラウドデータをいい感じに可視化したい. matplotlibでも3次元データのscatterを描くことができるが,以下のような感じでいまいちな見た目になってしまう.

f:id:ksknw:20191029190640p:plain

もうちょっといい感じの図が描きたい.たとえはこんな感じのやつ. 調べてみると,blenderpythonが使えるらしいので,blenderなんもわからんけどやってみる. データはここで使われているShapenetの一部を用いた.

環境

ちょっと調べてでてくるスクリプトは動かないことが多かった(特にライト周り). 今使っているバージョン付近で大きめのアップデートがあったのか,blender自体がそういうアップデートの方針なのかわからないが,使っているバージョンによっては以下のスクリプトは動かないので注意.

Blenderのインストール

blenderはaptでもインストールできるが,最新版ではなかったのでsnapdというやつを使ってみることにした. 一般にパッケージ管理ソフトを複数混ぜるのは良くない気がするので,あまり良くない方法かもしれない.

sudo apt install snapd 
sudo snap install blender

terminalから

blender

で起動できる.

スクリプト

Blenderの中のScriptingタブからテキストエディタを開くことができる. Blenderから実行するpythonからはbpyというモジュールを使うことができ,これを使ってblenderの中のオブジェクトやカメラなどを操作できる. 全体はここに置いた.

オブジェクトの作成

ポイントクラウドを可視化するために,ポイントクラウドの各点に球を配置する. これは以下の手順でできる.

  1. 各位置にプリミティブを配置する.
  2. プリミティブの大きさを変更する.
  3. マテリアルを設定し,色を指定する.

プリミティブにはcubeやconeなど色々あるが,今回はuv_sphereを使う.

bpy.ops.mesh.primitive_uv_sphere_add(location=(x, y, z))

大きさの変更はオブジェクトを選択してスケールを変更することでできる. このあたりはかなりGUIpythonで操作している感じがある. bpy.context.object で作成したばかりのオブジェクト(選択済みになっているもの)を取得できる. 選択したオブジェクトのスケールを変更することで,球の大きさを変更できる.

sphere = bpy.context.object
sphere.scale = (0.02, 0.02, 0.02)

最後にマテリアルを追加することで,球の色を変更する. この機能を使うと,金属にするとか,水滴にするとかができる(と思う)が,今回は色を変更するだけ.

mat = bpy.data.materials.new("BLUE")
mat.diffuse_color = (.33, .43, .95, 1) # RGBA
sphere.data.materials.append(mat)

照明の配置

前述の通り、検索して出てくる多くのスクリプトが動かなかった。 これを参考にして,4方位から照らすようにした.

light_data = bpy.data.lights.new(name="light", type='AREA')
light_data.energy = 500
light_object = bpy.data.objects.new(name="light", object_data=light_data)
bpy.context.collection.objects.link(light_object)
bpy.context.view_layer.objects.active = light_object
light_object.location = loc
light_object.rotation_euler = rot

その他

カメラの配置はblender上で頑張って操作して決めた. renderのオプションなどは呪文として書いた.よくわかってない. 雰囲気的に床があったほうが良さそうだったので,適当に巨大な板を下に置いた.

bpy.ops.mesh.primitive_plane_add(location=(0, 0, -1))
plane = bpy.context.object
plane.scale = (100, 100, 1)

ということで以下のようなものを書いた.

実行

Blenderの中から実行しても良いし,コマンドラインから実行してもよい. jupyter notebookの中からも実行できるらしいが,まだやっていない (おそらく/snap/blender/33/2.80/python以下にipythonを入れればできそう).

コマンドラインから実行する場合は以下.

blender --python visualize_points_blender.py

--backgroundオプションをつけるとbackgroundで実行できるらしいが,セグフォでうまくできなかった.

実行すると以下のような図を得る.

f:id:ksknw:20191029190609p:plain

まとめ

Blenderpythonから使ってポイントクラウドを可視化した. GPUが弱いからか,1000点ぐらいのポイントクラウドを可視化するのに3分ぐらいかかっている. いい感じの図を作るセンスがほしい.

参考