# Cython / Python / Numba.jitの速度比較パートⅡ

cython, python, numba.jitの速度の比較パートⅡ。コードはここから拝借した。

スポンサーリンク

## cython, python, numba.jitの速度を比較

### ジュリア集合描画(python編)

import numpy as np
import numba
import matplotlib.pyplot as plt
import cython

def py_julia_fractal(z_re, z_im, j):
    """Fill ``j`` with escape-time counts for the Julia set of c = -0.05 + 0.68j.

    For every starting point ``z_re[m] + 1j * z_im[n]`` the map
    ``z -> z**2 + c`` is iterated at most 256 times.  The index of the first
    iteration after which ``|z| > 2`` is stored in ``j[m, n]``.

    The body uses only scalar arithmetic and array indexing so that the same
    function can also be compiled with ``numba.jit(nopython=True)``.

    Args:
        z_re: 1-D array of real parts of the starting points.
        z_im: 1-D array of imaginary parts of the starting points.
        j: 2-D integer array indexed as ``j[m, n]``; modified in place.

    Returns:
        None.  Entries of ``j`` whose point does not escape within 256
        iterations are left unchanged.
    """
    for m, re0 in enumerate(z_re):
        for n, im0 in enumerate(z_im):
            z = re0 + 1j * im0
            for t in range(256):
                z = z ** 2 - 0.05 + 0.68j
                # |z| > 2 is tested as |z|**2 > 4: this avoids a square root
                # and a NumPy call per iteration, and matches the escape test
                # used by the Cython implementation.
                if z.real * z.real + z.imag * z.imag > 4.0:
                    j[m, n] = t
                    break

import time

# Sample an N x N grid of starting points over [-1.5, 1.5] x [-1.5, 1.5].
N = 1024
j = np.zeros((N, N), np.int64)
z_real = np.linspace(-1.5, 1.5, N)
z_imag = np.linspace(-1.5, 1.5, N)

# Time only the fractal computation; figure creation is kept outside the
# measured region so it does not inflate the result.
start_time = time.perf_counter()
py_julia_fractal(z_real, z_imag, j)
elapsed_time = time.perf_counter() - start_time

fig, ax = plt.subplots(figsize=(14, 14))
# j is indexed as j[real, imag] while imshow draws the first index along the
# vertical axis, so transpose and put row 0 at the bottom to match the
# extent and the axis labels.
ax.imshow(j.T, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5], origin="lower")
# Raw strings keep the LaTeX backslashes from being read as string escapes.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
print("Elapsed time: {}".format(elapsed_time))

Elapsed time: 40.08418655395508


### ジュリア集合描画(numba.jit編)

# Compile the pure-Python implementation with Numba in nopython mode.
jit_julia_fractal = numba.jit(nopython=True)(py_julia_fractal)

# Warm-up call on a 1x1 problem with the same argument types, so that the
# one-off JIT compilation is not counted in the timing below.
jit_julia_fractal(z_real[:1], z_imag[:1], np.zeros((1, 1), np.int64))

# Time only the fractal computation; figure creation is kept outside the
# measured region so it does not inflate the result.
start_time = time.perf_counter()
jit_julia_fractal(z_real, z_imag, j)
elapsed_time = time.perf_counter() - start_time

fig, ax = plt.subplots(figsize=(14, 14))
# j is indexed as j[real, imag] while imshow draws the first index along the
# vertical axis, so transpose and put row 0 at the bottom to match the
# extent and the axis labels.
ax.imshow(j.T, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5], origin="lower")
# Raw strings keep the LaTeX backslashes from being read as string escapes.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
print("Elapsed time: {}".format(elapsed_time))

Elapsed time: 0.4300670623779297

40.08418655395508/0.4300670623779297

93.20450241485904

numba.jitはpythonの約93倍の速度でジュリア集合を描画した。

### ジュリア集合描画(cython編)

%load_ext cython

%%cython -a
cimport numpy
cimport cython

ctypedef numpy.int64_t ITYPE_t
ctypedef numpy.float64_t FTYPE_t

cpdef inline double abs2(double complex z):
    """Return the squared magnitude of ``z``, i.e. ``real**2 + imag**2``."""
    cdef double re = z.real
    cdef double im = z.imag
    return re * re + im * im

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_julia_fractal(numpy.ndarray[FTYPE_t, ndim=1] z_re,
                     numpy.ndarray[FTYPE_t, ndim=1] z_im,
                     numpy.ndarray[ITYPE_t, ndim=2] j):
    """Fill ``j`` with escape-time counts for the Julia set of c = -0.05 + 0.68j.

    For every starting point ``z_re[m] + 1j * z_im[n]`` the map
    ``z -> z**2 + c`` is iterated at most 256 times.  The index of the first
    iteration after which ``abs2(z) > 4`` is stored in ``j[m, n]``.

    Args:
        z_re: 1-D float64 array of real parts of the starting points.
        z_im: 1-D float64 array of imaginary parts of the starting points.
        j: 2-D int64 array of shape at least ``(len(z_re), len(z_im))``;
            modified in place.  Entries whose point does not escape within
            256 iterations are left unchanged.

    Raises:
        ValueError: if ``j`` is too small to hold the result.
    """
    # Py_ssize_t is the index type for array extents; a C int could overflow.
    cdef Py_ssize_t m, n
    cdef Py_ssize_t M = z_re.shape[0], N = z_im.shape[0]
    cdef int t
    cdef double complex z
    # Bounds checking is disabled for this function, so an undersized output
    # array would be written out of bounds; validate once up front instead.
    if j.shape[0] < M or j.shape[1] < N:
        raise ValueError("j must have shape at least (len(z_re), len(z_im))")
    for m in range(M):
        for n in range(N):
            z = z_re[m] + 1.0j * z_im[n]
            for t in range(256):
                z = z ** 2 - 0.05 + 0.68j
                if abs2(z) > 4.0:
                    j[m, n] = t
                    break

# Time only the fractal computation; figure creation is kept outside the
# measured region so it does not inflate the result.
start_time = time.perf_counter()
cy_julia_fractal(z_real, z_imag, j)
elapsed_time = time.perf_counter() - start_time

fig, ax = plt.subplots(figsize=(14, 14))
# j is indexed as j[real, imag] while imshow draws the first index along the
# vertical axis, so transpose and put row 0 at the bottom to match the
# extent and the axis labels.
ax.imshow(j.T, cmap=plt.cm.RdBu_r,
          extent=[-1.5, 1.5, -1.5, 1.5], origin="lower")
# Raw strings keep the LaTeX backslashes from being read as string escapes.
ax.set_xlabel(r"$\mathrm{Re}(z)$", fontsize=18)
ax.set_ylabel(r"$\mathrm{Im}(z)$", fontsize=18)
fig.tight_layout()
print("Elapsed time: {}".format(elapsed_time))

Elapsed time: 0.14763426780700684

0.4300670623779297/0.14763426780700684

2.9130571700340586

cythonはnumba.jitの約3倍の速度でジュリア集合を描画した。今回の計測でも、やはりcythonが最速という結果になった。

スポンサーリンク

## cudaを使ったマンデルブロ集合描画

ここから引っ張ってきたコードで描いたマンデルブロ集合描画(cuda編)

from pylab import imshow, show
from timeit import default_timer as timer
from numba import cuda
from numba import *

@cuda.jit(device=True)
def mandel(x, y, max_iters):
    """Return the escape iteration of the point ``x + y*1j``.

    Starting from zero, ``z -> z*z + c`` is applied with ``c = x + y*1j``.
    The index of the first iteration after which ``|z|**2 >= 4`` is
    returned; if that never happens within ``max_iters`` iterations,
    ``max_iters`` is returned.
    """
    point = complex(x, y)
    value = 0.0j
    for step in range(max_iters):
        value = value * value + point
        re = value.real
        im = value.imag
        if re * re + im * im >= 4:
            return step

    return max_iters

@cuda.jit
def mandel_kernel(min_x, max_x, min_y, max_y, image, iters):
    """CUDA kernel: write ``mandel`` escape counts into ``image``.

    ``image`` is indexed as ``image[row, col]``.  Column ``col`` maps to the
    real coordinate ``min_x + col * (max_x - min_x) / width`` and row ``row``
    to the imaginary coordinate ``min_y + row * (max_y - min_y) / height``.
    Each thread starts at its own absolute (x, y) grid position and advances
    by the total number of threads along each axis, so any launch
    configuration covers the whole image.
    """
    width = image.shape[1]
    height = image.shape[0]

    dx = (max_x - min_x) / width
    dy = (max_y - min_y) / height

    # Absolute thread position and total thread count along x and y.
    first_col, first_row = cuda.grid(2)
    col_step, row_step = cuda.gridsize(2)

    for col in range(first_col, width, col_step):
        re = min_x + col * dx
        for row in range(first_row, height, row_step):
            im = min_y + row * dy
            image[row, col] = mandel(re, im, iters)

# Output image: 1024 rows x 1536 columns of 8-bit iteration counts.
gimage = np.zeros((1024, 1536), dtype=np.uint8)
# Launch configuration: threads per block and blocks per grid, as (x, y).
blockdim = (32, 8)
griddim = (32, 16)

# The timed region covers the host-to-device copy, the kernel launch and the
# device-to-host copy.
# NOTE(review): the first launch presumably also includes Numba's JIT
# compilation of the kernel; confirm before comparing against other timings.
start = timer()
d_image = cuda.to_device(gimage)
mandel_kernel[griddim, blockdim](-2.0, 1.0, -1.0, 1.0, d_image, 30)
# Copy the result back into gimage with the documented device-array API
# (replaces the legacy to_host() call).
d_image.copy_to_host(gimage)
dt = timer() - start

print("Mandelbrot created on GPU in %f s" % dt)

imshow(gimage)

Mandelbrot created on GPU in 0.174355 s

<matplotlib.image.AxesImage at 0x7f067015bef0>
スポンサーリンク
スポンサーリンク

フォローする