fastai tutorial:GPU-accelerated mean shift in pytorchをやる

今日は、fastaiのpytorch gpuを使ったチュートリアルをやる。途中までは前回の続きなので、コードだけを貼り付けていく。今日のチュートリアル学習部分は、面倒くさかったので、翻訳の方はかなりいい加減に仕上がっている。

cd git/fastai
/home/workspace/git/fastai
%matplotlib inline
import math
import numpy as np
import matplotlib.pyplot as plt
import operator
import torch
from fastai.core import *
n_clusters=6
n_samples =250
centroids = np.random.uniform(-35, 35, (n_clusters, 2))
slices = [np.random.multivariate_normal(centroids[i], np.diag([5., 5.]), n_samples)
           for i in range(n_clusters)]
data = np.concatenate(slices).astype(np.float32)
def plot_data(centroids, data, n_samples):
    colour = plt.cm.rainbow(np.linspace(0,1,len(centroids)))
    for i, centroid in enumerate(centroids):
        samples = data[i*n_samples:(i+1)*n_samples]
        plt.rcParams['figure.figsize'] = 12, 8
        plt.rcParams["font.size"] = "17"
        plt.scatter(samples[:,0], samples[:,1], c=colour[i], s=1)
        plt.plot(centroid[0], centroid[1], markersize=10, marker="x", color='k', mew=5)
        plt.plot(centroid[0], centroid[1], markersize=5, marker="x", color='m', mew=2)
スポンサーリンク

Mean shift

クラスタリングアルゴリズムに出くわすほとんどの人々は、k-meansについて学んでいる。Mean shift clusteringは、新しくあまり知られていないアプローチだが、いくつかの重要な利点を持っている。

  • 事前にクラスター数を選ぶ必要がない代わりに、自動で簡単に選べるバンド幅を指定するだけでいい。
  • それがどんな形のクラスタも取り扱える一方で、k-meansは(special extensionsを使うことなく)、クラスタがほぼボール形である必要がある。

アルゴリズムは以下の通り。

  • サンプルXの各データポイントxに対し、そのポイントxとX中の他の全てのポイント間の距離を見つけ出す。
  • Xまでのそのポイントの距離のGaussian kernel(ガウスカーネル)を使用して、X中の各ポイントに対する重みを作成する
    • この重み付け手法はxからさらにポイントを遠ざけるペナルティを与える。
    • 重みがゼロになる速度は、standard deviation of the Gaussian(ガウス分布の標準偏差)であるバンド幅によって決定される。
  • 前のステップを基に荷重されたXの他の全てのポイントのweighted average(加重平均)としてxを更新する。

これが、互いに近いポイントを、それらが隣り同士になるまで徐々に近付けていく。

ということで、過去に習っているだろうガウスカーネルの定義を以下に示す。

from numpy import exp, sqrt, array, abs
def gaussian(d, bw): return exp(-0.5*((d/bw))**2) / (bw * math.sqrt(2*math.pi))

全ての距離はプラスなので、ガウス分布の右手側だけを使用する。異なるバンド幅(bw)に対してどのように見えるのかを以下に示す。

plt.rcParams['figure.figsize'] = 12, 8
plt.rcParams["font.size"] = "17"
x=np.linspace(0,5)
fig, ax = plt.subplots()
ax.plot(x, gaussian(x, 1), label='bw=1');
ax.plot(x, gaussian(x, 2.5), label='bw=2.5')
ax.legend();

ポイント間の距離を計算する必要がある。下に使用する関数を定義する。

def distance(x, X): return sqrt(((x-X)**2).sum(1))

関数を試す(この関数の仕組みについては後程説明する)。

d = distance(array([2,3]), array([[1,2],[2,3],[-1,1]])); d
array([1.41421, 0.     , 3.60555])

このケースでどの程度の重みを得るかを見るために、ガウス関数に距離をフィードすることができる。

gaussian(d, 2.5)
array([0.13598, 0.15958, 0.0564 ])

アルゴリズムのsingle iteration(単一反復)を定義するために、これらのステップをまとめることができる。

def meanshift_inner(x, X, bandwidth):
    # Find distance from point x to every other point in X
    dist = distance(x, X)

    # Use gaussian to turn into array of weights    
    weight = gaussian(dist, bandwidth)
    # Weighted sum (see next section for details)
    return (weight[:,None]*X).sum(0) / weight.sum()
    
def meanshift_iter(X, bandwidth=2.5):
    return np.array([meanshift_inner(x, X, bandwidth) for x in X])
X=meanshift_iter(data)

結果は、望み通り、全ての点が、(アルゴリズムがクラスタの存在位置を実際には知らないのだが)それらの真のクラスタ中心に近付くことを示している。

plot_data(centroids, X, n_samples)

これを数回繰り返せば、クラスタをより正確にすることができる。

def meanshift(X, it=0, max_it=5, bandwidth=2.5, eps=0.000001):
    # perform meanshift once
    new_X = meanshift_iter(X, bandwidth=bandwidth)
    # if we're above the max number of allowed iters
    # or if our approximations have converged
    if it >= max_it or abs(X-new_X).sum()/abs(X.sum()) < eps:
        return new_X
    else:
        return meanshift(new_X, it+1, max_it, bandwidth, eps)
%time X=meanshift(data)
CPU times: user 884 ms, sys: 0 ns, total: 884 ms
Wall time: 882 ms

mean shift clusteringが、オリジナルのクラスタリングをほぼ再現することを見ることができる。一つの例外が、非常に近いクラスタにおいてだが、もし、それらを本当に差別化したい場合は、バンド幅を下げてやればいい。

特筆すべきは、このアルゴリズムが、そこにいくつのクラスタがあるのかを教えることなしに、オリジナルクラスタをほぼ再現することだ(下のチャートの中で、セントロイドをやや右にオフセットしているが、そうしなければ、それらが今は互いに重なり合っているので、点を視認することができなくなる)。

plot_data(centroids+2, X, n_samples)

GPU-accelerated mean shift in pytorch

pytorchの一つの利点が、numpyに似ているということだろう。例えば、実際に、ガウス分布と距離、meanshift_iterの定義は同じなので、単純に、2つのnumpy関数の代わりのPyTorch実装をインポートすればいい。

from torch import exp, sqrt

次に、これまでと全く同じコードを使うが、最初に、numpyアレイをGPU PyTorchテンソルに変換する。

def meanshift_iter_torch(X, bandwidth=2.5):
    out = torch.stack([meanshift_inner(x, X, bandwidth) for x in X], 0)
    return to_gpu(out.cuda())

def meanshift_torch(X_torch, it=0, max_it=5, bandwidth=2.5, eps=0.000001):
    new_X = meanshift_iter_torch(X_torch, bandwidth=bandwidth)
    if it >= max_it or abs(X_torch-new_X).sum()/abs(X_torch.sum()) < eps:
        return new_X
    else:
        return meanshift_torch(new_X, it+1, max_it, bandwidth, eps)
X_torch = to_gpu(torch.from_numpy(X))
%time X = meanshift_torch(X_torch).cpu().numpy()
plot_data(centroids+2, X, n_samples)
CPU times: user 157 ms, sys: 24 ms, total: 181 ms
Wall time: 169 ms

うまくいったが、この実装は処理速度が遅い。何故なら、各ループが新しいcudaカーネルを立ち上げ、それが全体的にアルゴリズムの速度を遅くしているからだ。さらに、各ループは、GPUの全スレッドを埋めるための十分な処理能力を有していない。GPUを効果的に使うには、バッチデータを一度に処理する必要がある。

GPU batched algorithm

バッチデータを処理するには、関数のバッチ化バージョンが必要になる。バッチを取り扱うdistance()のバッチ化バージョンを下に示す。

def distance_b(a,b): return sqrt(((a[None,:] - b[:,None]) ** 2).sum(2))
a=torch.rand(2,2)
b=torch.rand(3,2)
distance_b(b, a)
 0.4822  0.7838  0.5279
 0.2641  0.7058  0.1938
[torch.FloatTensor of size 2x3]

distance_b()の2つのパラメータが、2つの異なる場所(a[None,:] and b[:,None])にどのようにしてunit axisを付け加えたかに留意する。これは、あらゆる関数に対する外積の概念を効果的に一般化する便利な手法である。今回のケースでは、(バッチ)aの全ての点から(全データセット)bの全ての点までの距離を得るのに使用している。

適切な距離関数を得たので、バッチデータを取り扱えるようにするために、meanshift関数にいくつかのマイナーアップデートを加えることができる。

def meanshift_gpu(X, it=0, max_it=5, bandwidth=2.5, eps=0.000001):
    weights = gaussian(distance_b(X, X), bandwidth)
    num = (weights[:,:,None] * X).sum(1)
    X_new = num / weights.sum(1)[:,None]
    
    if it >= max_it or abs(X_new - X).sum()/abs(x.sum()) < eps:
        return X_new
    else:
        return meanshift_gpu(X_new, it+1, max_it, bandwidth, eps)

各ループは、それでもなお、新しいcuda kernelを立ち上げる必要があるが、今回はより少ないループで、なおかつ、点をまとめて更新することから得られる速度アップが、それを補うのには十分である。

X_torch = to_gpu(torch.from_numpy(data))
%time X = meanshift_gpu(X_torch).cpu().numpy()
CPU times: user 31.5 ms, sys: 28.2 ms, total: 59.7 ms
Wall time: 58.3 ms

That’s more like it!(よし!)、882msから58.3msに短縮できた。なんと、15.12倍もスピードアップできた!そして、答えも正しい!

plot_data(centroids+2, X, n_samples)