cython, python, numba.jitの速度の比較パートⅡ。コードはここから拝借した。
スポンサーリンク
cython, python, numba.jitの速度を比較¶
ジュリア集合描画(python編)¶
import numpy as np
import numba
import matplotlib.pyplot as plt
import cython
def py_julia_fractal(z_re, z_im, j):
    """Fill ``j`` with escape-time counts for the Julia set of ``z**2 + c``.

    For every grid point ``z_re[m] + 1j * z_im[n]`` the map
    ``z -> z**2 + c`` with ``c = -0.05 + 0.68j`` is iterated at most 256
    times.  The index ``t`` of the first iteration after which ``|z| > 2``
    is stored in ``j[m, n]``.

    Args:
        z_re: 1-D sequence of real parts; its index is the first index of ``j``.
        z_im: 1-D sequence of imaginary parts; its index is the second
            index of ``j``.
        j: 2-D array indexable as ``j[m, n]``; it is modified in place.

    Returns:
        None.  Entries of ``j`` whose point does not escape within the 256
        iterations are left untouched, so the caller decides their value
        (the surrounding script passes a zero-initialised array).
    """
    # Plain index loops are kept on purpose: this same function object is
    # later compiled with numba.jit(nopython=True).
    for m in range(len(z_re)):
        for n in range(len(z_im)):
            z = z_re[m] + 1j * z_im[n]
            for t in range(256):
                z = z ** 2 - 0.05 + 0.68j
                # Testing z.real**2 + z.imag**2 > 4.0 instead would avoid
                # the square root and be a bit faster.
                if np.abs(z) > 2.0:
                    j[m, n] = t
                    break
# Resolution of the square grid on which the fractal is evaluated.
N = 1024
# Escape-time counts; zero-initialised because py_julia_fractal only writes
# the entries of points that escape.
j = np.zeros((N, N), np.int64)
z_real = np.linspace(-1.5, 1.5, N)
z_imag = np.linspace(-1.5, 1.5, N)

import time

# NOTE: the timed region covers the computation *and* the construction of
# the matplotlib figure, not the fractal computation alone.
start_time = time.time()
py_julia_fractal(z_real, z_imag, j)
fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(j, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5])
# Raw strings: "\m" is not a valid escape sequence in a normal string
# literal; the resulting label text is unchanged.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
elapsed_time = time.time() - start_time
print("Elapsed time: {}".format(elapsed_time))
ジュリア集合描画(numba.jit編)¶
# Wrap the pure-Python implementation with Numba in nopython mode.
jit_julia_fractal = numba.jit(nopython=True)(py_julia_fractal)

# NOTE: no signature is passed to numba.jit, so compilation is lazy and
# happens on this first call; the measured time therefore includes JIT
# compilation as well as the construction of the matplotlib figure.
start_time = time.time()
jit_julia_fractal(z_real, z_imag, j)
fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(j, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5])
# Raw strings: "\m" is not a valid escape sequence in a normal string
# literal; the resulting label text is unchanged.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
elapsed_time = time.time() - start_time
print("Elapsed time: {}".format(elapsed_time))

# Ratio of two measured elapsed times (presumably pure Python / numba.jit,
# copied by hand from the printed output -- confirm against an actual run).
40.08418655395508 / 0.4300670623779297
ジュリア集合描画(cython編)¶
%load_ext cython
%%cython -a
cimport numpy
cimport cython
ctypedef numpy.int64_t ITYPE_t
ctypedef numpy.float64_t FTYPE_t
cpdef inline double abs2(double complex z):
    """Return the squared magnitude of ``z`` (real**2 + imag**2), without a square root."""
    return z.real * z.real + z.imag * z.imag
@cython.boundscheck(False)
@cython.wraparound(False)
def cy_julia_fractal(numpy.ndarray[FTYPE_t, ndim=1] z_re,
                     numpy.ndarray[FTYPE_t, ndim=1] z_im,
                     numpy.ndarray[ITYPE_t, ndim=2] j):
    """Fill ``j`` with escape-time counts for the Julia set of ``z**2 + c``.

    Typed Cython counterpart of the pure-Python loop: for every grid point
    ``z_re[m] + 1j * z_im[n]`` the map ``z -> z**2 + c`` with
    ``c = -0.05 + 0.68j`` is iterated at most 256 times, and the index of
    the first iteration after which ``abs2(z) > 4.0`` is written to
    ``j[m, n]``.  Entries of points that never escape are left untouched.

    Bounds checking and negative-index wraparound are disabled by the
    decorators, so ``j`` must provide at least ``z_re.size`` rows and
    ``z_im.size`` columns; this is not verified here.
    """
    cdef int m, n, t, M = z_re.size, N = z_im.size
    cdef double complex z
    for m in range(M):
        for n in range(N):
            z = z_re[m] + 1.0j * z_im[n]
            for t in range(256):
                z = z ** 2 - 0.05 + 0.68j
                # Squared magnitude against 4.0: same test as |z| > 2
                # without computing a square root.
                if abs2(z) > 4.0:
                    j[m, n] = t
                    break
# NOTE: as in the other benchmarks, the timed region covers the computation
# and the construction of the matplotlib figure.
start_time = time.time()
cy_julia_fractal(z_real, z_imag, j)
fig, ax = plt.subplots(figsize=(14, 14))
ax.imshow(j, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5])
# Raw strings: "\m" is not a valid escape sequence in a normal string
# literal; the resulting label text is unchanged.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
elapsed_time = time.time() - start_time
print("Elapsed time: {}".format(elapsed_time))

# Ratio of two measured elapsed times (presumably numba.jit / cython,
# copied by hand from the printed output -- confirm against an actual run).
0.4300670623779297 / 0.14763426780700684
cythonはnumba.jitの約3倍高速でジュリア集合を描画した。やはりcythonは最速という結果になった。ただし、上の計測時間にはいずれもグラフ描画の時間が含まれており、numba.jitの場合は初回呼び出し時のJITコンパイル時間も含まれている点に注意。
スポンサーリンク
cudaを使ったマンデルブロ描画¶
ここから引っ張ってきたコードで描いたマンデルブロ集合描画(cuda編)
from pylab import imshow, show
from timeit import default_timer as timer
from numba import cuda
from numba import *
@cuda.jit(device=True)
def mandel(x, y, max_iters):
    """
    Given the real and imaginary parts of a complex number,
    determine if it is a candidate for membership in the Mandelbrot
    set given a fixed number of iterations.

    Starting from ``z = 0`` the map ``z -> z*z + c`` with
    ``c = complex(x, y)`` is iterated.  Returns the index of the first
    iteration after which ``|z|**2 >= 4``, or ``max_iters`` if that never
    happens within ``max_iters`` iterations.
    """
    c = complex(x, y)
    z = 0.0j
    for i in range(max_iters):
        z = z * z + c
        # Squared magnitude against 4: equivalent to |z| >= 2 without a sqrt.
        if (z.real * z.real + z.imag * z.imag) >= 4:
            return i
    return max_iters
@cuda.jit
def mandel_kernel(min_x, max_x, min_y, max_y, image, iters):
    """Write Mandelbrot escape counts for a rectangle of the plane into ``image``.

    ``image`` is indexed as ``image[y, x]``; ``image.shape[0]`` is taken as
    the height and ``image.shape[1]`` as the width.  Pixel ``(x, y)`` maps to
    the complex point ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``
    and receives the value returned by ``mandel`` for that point with
    ``iters`` as the iteration limit.
    """
    height = image.shape[0]
    width = image.shape[1]

    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height

    # Global index of this thread along each axis ...
    startX = cuda.blockDim.x * cuda.blockIdx.x + cuda.threadIdx.x
    startY = cuda.blockDim.y * cuda.blockIdx.y + cuda.threadIdx.y
    # ... and the total number of threads along each axis, used as the loop
    # stride so the launch grid may be smaller than the image.
    gridX = cuda.gridDim.x * cuda.blockDim.x
    gridY = cuda.gridDim.y * cuda.blockDim.y

    for x in range(startX, width, gridX):
        real = min_x + x * pixel_size_x
        for y in range(startY, height, gridY):
            imag = min_y + y * pixel_size_y
            image[y, x] = mandel(real, imag, iters)
# Host-side output image, indexed as [row, column].
gimage = np.zeros((1024, 1536), dtype=np.uint8)
# Kernel launch configuration: threads per block and blocks per grid.
blockdim = (32, 8)
griddim = (32, 16)

# The timed region covers the host-to-device copy, the kernel launch and the
# device-to-host copy.
start = timer()
d_image = cuda.to_device(gimage)
mandel_kernel[griddim, blockdim](-2.0, 1.0, -1.0, 1.0, d_image, 30)
# copy_to_host(ary) fills the given host array in place; it replaces the
# deprecated DeviceNDArray.to_host() call.
d_image.copy_to_host(gimage)
dt = timer() - start

print("Mandelbrot created on GPU in %f s" % dt)
imshow(gimage)
スポンサーリンク
スポンサーリンク