GPUとCPUのMandelbrot set描画の速度比較をしてみた。GPU versionにはnumba.cuda.jitを使用した。コードはここから拝借させてもらった。一日も早くこういうコードを自作してみたいものだ。
スポンサーリンク
マンデルブロ描画(CPU version)
def escape_time(p, maxtime):
    """Count Mandelbrot iterations until the orbit of ``p`` escapes.

    Starting from z = 0, the map z <- z**2 + p is applied repeatedly.
    As soon as |z| exceeds 2 the orbit is known to diverge, and the
    zero-based index of that iteration is returned. If no escape happens
    within ``maxtime`` iterations, ``maxtime`` itself is returned.

    This function is also handed to ``cuda.jit(device=True)`` later in the
    file, so it sticks to plain scalar arithmetic and a simple loop.
    """
    z = 0j
    for step in range(maxtime):
        z = p + z ** 2
        if abs(z) > 2:
            # |z| > 2 guarantees divergence; report when it happened.
            return step
    # Never escaped within the budget: treat p as (probably) in the set.
    return maxtime
import numpy
import matplotlib.pyplot as plt
import time
# CPU benchmark: fill an ny x nx escape-time image of the rectangle
# rlim x ilim in the complex plane with a pure-Python double loop.
# time.perf_counter() is a monotonic, high-resolution clock meant for
# measuring durations; time.time() follows the wall clock and can jump.
start_time = time.perf_counter()
maxiter = 300
rlim = (-2.2, 1.5)  # real-axis range, mapped onto image columns
ilim = (-1.5, 1.5)  # imaginary-axis range, mapped onto image rows
nx = 2048
ny = 2048
dx = (rlim[1] - rlim[0]) / nx
dy = (ilim[1] - ilim[0]) / ny
M = numpy.zeros((ny, nx), dtype=int)
for i in range(ny):
    for j in range(nx):
        # Row i -> imaginary part, column j -> real part.
        p = rlim[0] + j * dx + (ilim[0] + i * dy) * 1j
        M[i, j] = escape_time(p, maxiter)
plt.imshow(M, interpolation="nearest")
elapsed_time = time.perf_counter() - start_time
print(f"Elapsed time: {elapsed_time}")
スポンサーリンク
マンデルブロ描画(GPU version)
import numpy
from numba import cuda
import matplotlib.pyplot as plt
# Reuse the pure-Python escape_time defined above by compiling it as a CUDA
# device function, so the kernel below can call it once per pixel.
escape_time_gpu = cuda.jit(device=True)(escape_time)
@cuda.jit
def mandelbrot_gpu(M, real_min, real_max, imag_min, imag_max):
    """Calculate the Mandelbrot set on the GPU.

    Each CUDA thread fills one element ``M[i, j]``. Row ``i`` is mapped to
    the imaginary axis and column ``j`` to the real axis. This is the same
    orientation as the CPU loop (``p = rlim[0] + j*dx + (ilim[0] + i*dy)*1j``).

    Parameters
    ----------
    M : numpy.ndarray
        A two-dimensional integer array of shape (ny, nx). The escape time
        of each point is written into it in place.
    real_min : float
        Minimum value on the real axis (column 0).
    real_max : float
        Maximum value on the real axis. It is exclusive, since column j maps
        to real_min + j * (real_max - real_min) / nx.
    imag_min : float
        Minimum value on the imaginary axis (row 0).
    imag_max : float
        Maximum value on the imaginary axis (exclusive, as above).

    Notes
    -----
    The iteration limit is hard-coded to 300, the same value as ``maxiter``
    in the CPU version.
    """
    ny, nx = M.shape
    # Absolute 2-D thread index; the first component indexes axis 0 (rows).
    i, j = cuda.grid(2)
    # The launch grid may be larger than M, so skip out-of-range threads.
    if i < ny and j < nx:
        dx = (real_max - real_min) / nx
        dy = (imag_max - imag_min) / ny
        # Column j -> real part, row i -> imaginary part. Each step size is
        # paired with the axis length it was computed from, so the result
        # stays correct when nx != ny and matches the CPU image orientation.
        p = real_min + dx * j + (imag_min + dy * i) * 1j
        M[i, j] = escape_time_gpu(p, 300)
# GPU benchmark: same image size and complex-plane window as the CPU run.
M = numpy.zeros((2048, 2048), dtype=numpy.int32)
block = (32, 32)
# Ceil-divide each array dimension by its block dimension so every element
# is covered by some thread. The kernel guards against the overhang.
grid = (
    (M.shape[0] + block[0] - 1) // block[0],
    (M.shape[1] + block[1] - 1) // block[1],
)
import time
# Warm-up launch: Numba compiles the kernel (and escape_time_gpu) on the
# first call. Running it once here keeps JIT compilation out of the timing.
mandelbrot_gpu[grid, block](M, -2.2, 1.5, -1.5, 1.5)
start_time = time.perf_counter()
mandelbrot_gpu[grid, block](M, -2.2, 1.5, -1.5, 1.5)
plt.imshow(M, interpolation="nearest")
elapsed_time = time.perf_counter() - start_time
print(f"Elapsed time: {elapsed_time}")
# Speedup from the original post's run, presumably the CPU elapsed time over
# the GPU elapsed time (about 84x). It was measured before the warm-up launch
# above was added, so the GPU figure there included JIT compilation.
34.46900153160095/0.40889859199523926
スポンサーリンク
スポンサーリンク