今日はこのサイトを参考にして、numpy.exp, cumath.exp, ElementwiseKernel, SourceModuleの速度比較をする。numpy.expというのは、”Calculate the exponential of all elements in the input array”「入力配列の全要素のネイピア数のべき乗(指数関数)を算出する。」関数のことらしい。
スポンサーリンク
速度比較用コードを作成¶
import pycuda.autoinit
import pycuda.driver as drv
from pycuda import gpuarray, cumath
from pycuda.elementwise import ElementwiseKernel
from pycuda.compiler import SourceModule
import numpy as np
start = drv.Event()
end = drv.Event()
kernel = ElementwiseKernel(
"double *a,double *b",
"b[i] = exp(a[i]);")
mod = SourceModule("""
__global__ void gexp(double *a,double *b,int n)
{
int i = threadIdx.x + blockIdx.x * blockDim.x;
while (i < n) {
b[i] = exp(a[i]);
i += blockDim.x * gridDim.x;
}
}
""")
knl = mod.get_function("gexp")
results = []
for N in [10**4, 10**5, 10**6, 10**7, 10**8]:
a = 2*np.ones(N,dtype=np.float64)
a_gpu = gpuarray.to_gpu(a)
b_gpu = gpuarray.zeros_like(a_gpu)
c_gpu = gpuarray.to_gpu(a)
d_gpu = gpuarray.zeros_like(c_gpu)
start.record()
np.exp(a)
end.record()
end.synchronize()
sec1 = start.time_till(end)*1e-3
print ("Numpy",sec1)
start.record() # start timing
kernel(a_gpu,b_gpu)
end.record() # end timing
end.synchronize()
sec2 = start.time_till(end)*1e-3
print ("Kernel",sec2)
print (np.allclose(np.exp(a),b_gpu.get()))
start.record() # start timing
knl(c_gpu,d_gpu,block=(1024,1,1),grid=(N//1024,1,1))
end.record() # end timing
end.synchronize()
sec3 = start.time_till(end)*1e-3
print ("knl",sec3)
print (np.allclose(np.exp(a),d_gpu.get()))
start.record()
cumath.exp(a_gpu)
end.record()
end.synchronize()
sec4 = start.time_till(end)*1e-3
print ("Cumath", sec4)
#print (np.allclose(np.exp(a),cumath.exp(a_gpu))
results.append([N,sec1,sec2,sec3,sec4])
スポンサーリンク
結果をグラフ化するコードを作成¶
import matplotlib.pyplot as plt
results = np.array(results)
legends = []
nH = results[:5, 0:1]
rows = results[:5,1:6]
plt.semilogx(nH,rows, 'o-')
legends += ['' + s for s in ['numpy','wise','source','cumath']]
plt.rcParams['figure.figsize'] = 18, 10
plt.rcParams["font.size"] = "20"
plt.ylabel('Seconds')
plt.xlabel('Value of N')
plt.legend(legends);
スポンサーリンク
スポンサーリンク