微分可能なソートの実装

はじめに

この論文を読んでおもしろいなと思ったので実装して試してみる。 ここではイントロは飛ばしてやり方の説明だけ。

論文の基本的なアイディアは1次元の最適輸送問題をSinkhornで解いて、得られた輸送計画からソートの結果を推定するという感じ

Preliminary

以下では微分可能なソートに使うSinkhornアルゴリズムと1次元の最適輸送問題についてそれぞれ説明する。

Sinkhornアルゴリズム

2つの点群(離散分布)の間の最適輸送について考える。 簡単のために点の数はともに$n$とし、各点の重みはすべて等しい($1/n$)とする。 このような2つの点群を$X = \{x_1, \ldots, x_n\}$と$Y = \{y_1, \ldots, y_n\}$と書く。 このとき、最適輸送は点同士の対応付けを、対応付けられた点間の距離の合計が最小になるように求める。 式で書くと以下。ただし、ground metric (点同士の距離)として二乗誤差$C$ ($C_{ij} = \|x_i - y_j\|^2_2$)を使う。

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt$$

ここで、$P$は対応付けを表す行列。$U$は対応付けの条件を表しており、$U=\{P\in [0,1]^{n\times n} | P \mathbf 1_n = 1/n, P^\top \mathbf 1_n = 1/n \}$である。例えば$x_1$と$y_2$が対応づけられる場合は$P_{12}=1/n$となる。この最適化問題線形計画問題なので単体法などで解くことができる。

単体法などは微分可能なアルゴリズムではないが、最適化問題を以下のように少し変更することで、微分可能なSinkhorn iterationという方法によって最適輸送問題を解くことができることが知られている。

$$\min_{P\in U(a,b)} \left \lt P, C \right \gt - H(P)$$

ここで,$H(P)$は輸送行列$P$のエントロピーを表し,$H(P) = - \sum_{i,j} P_{ij} \log P_{ij}$である。 また、$\epsilon$はエントロピー正則化の強さを表す。 Sinkhornの具体的な計算については以前書いたものを参考。

ksknw.hatenablog.com

1次元の最適輸送問題

2つの点群がそれぞれ1次元上に存在するときを考える。 見やすさのために$y$軸方向に少しずらしてプロットする。青点が$x_i$、オレンジ点が$y_j$を表す ($x_i, y_j \in \mathbb R$)。

import numpy as np
import pylab as plt
import torch
import torch.nn as nn

torch.manual_seed(1234)

n = 5

X = torch.randn(n)
Y = torch.randn(n)

plt.scatter(X, torch.zeros(n))
plt.scatter(Y, torch.ones(n))
plt.show()

1次元のときの最適輸送問題は簡単に解ける。点のうち左側にあるもの同士を対応付け、次に左から2点目を対応付け、…と左から順に対応付ければ最適解を得ることができる。 これは2つの点群をそれぞれソートする操作に対応する。

P = torch.zeros(n, n)
P[torch.sort(X)[1], torch.sort(Y)[1]] = 1/n

def plot(X, Y, P):
    plt.scatter(X, torch.zeros(n))
    plt.scatter(Y, torch.ones(n))
    for i,p_i in enumerate(P):
        for j,p_ij in enumerate(p_i):
            plt.plot([X[i], Y[j]], [0, 1],
                     alpha=p_ij.numpy()*(n-1),
                     color="gray")
    plt.show()

plot(X, Y, P)

実際に点群をソートするだけで1次元の最適輸送問題を解くことができた。

1次元の最適輸送は多くの場合ソートによって解かれる。これはSinkhornよりもソートのほうが計算量が少なくて済むからである。 しかし、当然Sinkhornでも1次元の最適輸送問題を解くことができる。

class Sinkhorn(nn.Module):
    def __init__(self, epsilon, steps):
        self.epsilon = epsilon
        self.steps = steps
        super(Sinkhorn, self).__init__()

    def softmin(self, A, log_P):
        ret = - self.epsilon * torch.logsumexp(-A / self.epsilon, dim=1)
        return ret

    def forward(self, C, a, b):
        Cmax = C.max()
        C = C/Cmax

        log_a = torch.log(a)
        log_b = torch.log(b)
        log_v = torch.zeros_like(log_b)

        P = a[:, None] @ b[None]
        log_P = torch.log(P)

        previous_cost = 100000
        for step_i in range(self.steps):
            log_u = self.epsilon * log_a + \
                self.softmin(C - log_v.unsqueeze(0), log_P)

            log_v = self.epsilon * log_b + \
                self.softmin(C.transpose(1, 0) - log_u.unsqueeze(0),
                             log_P.transpose(1, 0))
            log_P = (log_u.unsqueeze(1) - C +
                     log_v.unsqueeze(0)) / self.epsilon

            P = torch.exp(log_P)
            cost = torch.sum(P * C) * Cmax

            if (abs(cost.detach().tolist() - previous_cost) < 1e-16 * Cmax):
                break
            previous_cost = cost.detach().tolist()

        return cost, P

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Y[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X, Y, P)

実際にSinkhornを用いた場合でもソートした場合と似たような解を得ることができた。

微分可能なソート

1次元の最適輸送問題をSinkhornで解くことで微分可能なソートを実装する。 やることは簡単で、入力の一方(ここでは$X$)をソートしたいデータ、もう一方($Y$)をソート済みの1次元点群としてSinkhornをやるだけでいい。 $Y$はソート済みの点群であれば何でもいいが、ここでは特に[0,1]の範囲に等間隔に存在する点群とした。

Y = torch.linspace(0, 1, n)
plt.scatter(X, torch.zeros(n))
plt.scatter(Y, torch.ones(n))

このとき、Sinkhornを使って点の対応関係を求めると以下のようになる。

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Y[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X, Y, P)

ここで、$Y$はソート済みであるので、Yの1つ目の点($y_1$)と対応づいた$X$側の点は、$X$の中で最も小さい値を持っていることがわかる (1次元の最適輸送は両方の点をソートした場合と一致するので)。同様に$y_i$と対応づいた点は$i$番目に小さい値を持っている。これは以下のように$P$の列方向のargmaxを見ると確認できる。

P.argmax(dim=1), torch.sort(X)[1]
(tensor([2, 4, 0, 3, 1]), tensor([2, 4, 0, 3, 1]))

$X$のソート済みの値はこのようにargmaxを計算しても解けるけど、以下のように輸送行列ともとのデータとの掛け算でも計算できる。この方法でやればSinkhorn + 掛け算なので微分可能になっていることに注意。

P.t() @ X * n
tensor([-1.0106, -0.6100,  0.0467,  0.2172,  0.4024])

普通にソートすると以下。

torch.sort(X)[0]
tensor([-1.0115, -0.6123,  0.0461,  0.2167,  0.4024])

Sinkhornで求めたソート前後の値は厳密には一致していない。これはエントロピー正則化の影響で$P$が厳密には{0, 1/n}の値に張り付いていないことが原因であると思われる。実際に値を見てみても1と0だけになっているわけではないことがわかる。

P
tensor([[7.4395e-42, 6.9331e-04, 1.9931e-01, 4.0038e-08, 1.3129e-24],
        [0.0000e+00, 4.6160e-38, 1.7486e-16, 4.6286e-04, 2.0000e-01],
        [1.9954e-01, 3.2266e-20, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 1.6854e-15, 6.9372e-04, 1.9954e-01, 9.3685e-09],
        [4.6154e-04, 1.9931e-01, 2.6548e-34, 0.0000e+00, 0.0000e+00]])

せっかくなのでnn.Moduleとして書くと以下のような感じ

class SoftSort(nn.Module):
    def __init__(self, epsilon, steps):
        super(SoftSort, self).__init__()
        self.sinkhorn = Sinkhorn(epsilon=epsilon, 
                                steps=steps)
    def forward(self, x):
        anchors = torch.linspace(0,1,len(x))
        C = (x[:, None] - anchors[None])**2
        a = torch.ones(len(x))/len(x)
        b = torch.ones(len(anchors))/len(anchors)
        _, P = self.sinkhorn(C, a, b)
        return x @ P / a

勾配を使った最適化

Sinkhornを使ってソートが解けることを示した。このアルゴリズムの利点は、ソート後の値がSinkhorn+行列の掛け算のみで得られるため勾配が容易に計算できることである。せっかくなので、以下のような適当な問題を作って入力の$X$を最適化することを考える。

$X$のソート後の値がある点群$Z$と等しくなるように$X$を最適化。

$$\min_{X} \|Z - \text{sort}(X)\|^2$$

$Z$として、[-1, 0]に等間隔に配置された点を考えることにする。 (sortの部分だけbackprop切っても似たような解が得られると思うので例として微妙だけど簡単なのでこれをやる)

from torch.optim import SGD
soft_sort = SoftSort(epsilon=1e-2, steps=10000)
X.requires_grad = True
Z = torch.linspace(-1, 1, n)

optimizer = SGD([X], lr=1e-2)
hist_loss = []
for i in range(300):
    soft_sorted_x = soft_sort(X)
    loss = ((Z - soft_sorted_x)**2).sum()
    loss.backward()
    hist_loss.append(loss.detach().tolist())
    optimizer.step()
    optimizer.zero_grad()
torch.sort(X)[0], Z, soft_sorted_x
(tensor([-1.0200e+00, -5.0110e-01, -3.2722e-04,  5.0038e-01,  1.0193e+00],
        grad_fn=<SortBackward>),
 tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000]),
 tensor([-9.9967e-01, -5.0004e-01, -3.2446e-04,  4.9931e-01,  9.9897e-01],
        grad_fn=<DivBackward0>))
plt.plot(hist_loss)

sinkhorn = Sinkhorn(epsilon=1e-3, steps=10000)
_, P = sinkhorn(((X[:, None] - Z[None])**2), a=torch.ones(n)/n, b=torch.ones(n)/n)
plot(X.detach(), Z, P.detach())

ここで今回考えた最適化問題はソート後の$X$が等間隔に並ぶようにしただけであり、$X$の値自体はソートされていない状態のままであることに注意。

X
tensor([-3.2722e-04,  1.0193e+00, -1.0200e+00,  5.0038e-01, -5.0110e-01],
       requires_grad=True)

まとめ

Sinkhornを使ってソート済みの1次元点列との間の最適輸送問題を解くことで、微分可能な形でソートを計算できるという内容の論文を紹介した。 論文ではもう少し意味ありそうな実験をやっている他、微分可能なquantileの計算もやっている(ソート後の値の適切なインデックスのものを取ってくるだけでいい)。

Sinkhornも1次元の最適輸送問題がソートで解けることも知っていたのに、最適輸送で微分可能なソートができると最初に聞いたときは全然この方法を思いつかなかったので悔しい。

参考

Differentiable Ranking and Sorting using Optimal Transport