Drawing the Mandelbrot Set with numba cuda jit

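I compared how fast the Mandelbrot set can be drawn on the GPU and on the CPU. The GPU version uses numba.cuda.jit. The code was borrowed from here. I hope to be able to write this kind of code myself before long.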

Drawing the Mandelbrot set (CPU version)

def escape_time(p, maxtime):
    """Perform the Mandelbrot iteration until it's clear that p diverges
    or the maximum number of iterations has been reached.
    """
    z = 0j
    for i in range(maxtime):
        z = z ** 2 + p
        if abs(z) > 2:
            return i
    return maxtime
import numpy
import matplotlib.pyplot as plt
import time

start_time = time.time()

maxiter = 300
rlim = (-2.2, 1.5)
ilim = (-1.5, 1.5)
nx = 2048
ny = 2048

dx = (rlim[1] - rlim[0]) / nx
dy = (ilim[1] - ilim[0]) / ny

M = numpy.zeros((ny, nx), dtype=int)

for i in range(ny):
    for j in range(nx):
        p = rlim[0] + j * dx + (ilim[0] + i * dy) * 1j
        M[i, j] = escape_time(p, maxiter)

plt.imshow(M, interpolation="nearest")
elapsed_time = time.time() - start_time
print("Elapsed time: {}".format((elapsed_time)))
Elapsed time: 34.46900153160095
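
As a quick sanity check of the iteration, the sketch below repeats escape_time exactly as defined above and evaluates it at a few points whose behavior is easy to verify by hand. The sample points and the names maxiter and samples are illustrative choices, not part of the original code.

```python
def escape_time(p, maxtime):
    """Perform the Mandelbrot iteration until it's clear that p diverges
    or the maximum number of iterations has been reached.
    """
    z = 0j
    for i in range(maxtime):
        z = z ** 2 + p
        if abs(z) > 2:
            return i
    return maxtime

maxiter = 300

# 0 and -1 stay bounded forever (0 -> 0 -> ..., and -1 -> 0 -> -1 -> ...),
# so the loop runs to the iteration limit.
# 3 is already outside the radius-2 disc after the first step,
# and 1+1j leaves it on the second step (z becomes 1+3j).
samples = [0j, -1 + 0j, 3 + 0j, 1 + 1j]
results = [escape_time(p, maxiter) for p in samples]
print(results)  # [300, 300, 0, 1]
```

Points inside the set return maxiter, which is why the set itself shows up as the single brightest value in the image.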

Drawing the Mandelbrot set (GPU version)

import numpy
from numba import cuda
import matplotlib.pyplot as plt

escape_time_gpu = cuda.jit(device=True)(escape_time)

@cuda.jit
def mandelbrot_gpu(M, real_min, real_max, imag_min, imag_max):
    """Calculate the Mandelbrot set on the GPU.
    
    Parameters
    ----------
    M : numpy.ndarray
        a two-dimensional integer array that will contain the 
        escape times for each point.
    real_min: float
        minimum value on the real axis
    real_max: float
        maximum value on the real axis
    imag_min: float
        minimum value on the imaginary axis
    imag_max: float
        maximum value on the imaginary axis
    """
    ny, nx = M.shape
    i, j = cuda.grid(2)
    
    if i < ny and j < nx:
        dx = (real_max - real_min) / nx
        dy = (imag_max - imag_min) / ny
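        # Row index i maps to the imaginary axis and column index j to the
        # real axis, matching the CPU version.
        p = real_min + dx * j + (imag_min + dy * i) * 1j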
        M[i, j] = escape_time_gpu(p, 300)
M = numpy.zeros((2048, 2048), dtype=numpy.int32)
block = (32, 32)
# Round up so the grid covers the whole array even when a dimension
# is not a multiple of the block size.
grid = ((M.shape[0] + block[0] - 1) // block[0],
        (M.shape[1] + block[1] - 1) // block[1])
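
The grid size has to be rounded up so that every array element gets a thread; the bounds check in the kernel (if i < ny and j < nx) then discards the surplus threads. A minimal sketch of that rounding in plain Python, runnable without a GPU. The helper name grid_for is hypothetical and not part of Numba.

```python
def grid_for(shape, block):
    """Return the number of blocks per axis needed to cover `shape`.

    Uses integer ceiling division: (n + b - 1) // b.
    """
    return tuple((n + b - 1) // b for n, b in zip(shape, block))

# Shape is an exact multiple of the block size: no surplus threads.
print(grid_for((2048, 2048), (32, 32)))  # (64, 64)

# Shape is not a multiple: the last block along each axis is only partly used.
print(grid_for((2050, 1000), (32, 32)))  # (65, 32)
```

With a (32, 32) block, each block already holds 1024 threads, so only the grid dimensions change with the image size.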
import time

start_time = time.time()
mandelbrot_gpu[grid, block](M, -2.2, 1.5, -1.5, 1.5)
plt.imshow(M, interpolation="nearest")
elapsed_time = time.time() - start_time
print("Elapsed time: {}".format((elapsed_time)))
Elapsed time: 0.40889859199523926
34.46900153160095/0.40889859199523926
84.29718812042833

The GPU version turned out to be about 84 times faster than the CPU version.
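
Both measurements above include the plt.imshow call, and if the timed GPU launch was the first call of the kernel, it probably also includes Numba's JIT compilation, so the ratio is a rough figure. The sketch below shows one way to separate one-off costs from steady-state time: a warm-up call followed by repeated timed calls. The names time_call and busy are hypothetical, and busy is only a stand-in workload so that the sketch runs without a GPU.

```python
import time

def time_call(fn, *args, warmup=1, repeat=3):
    """Return the best wall-clock time of fn(*args) over `repeat` runs."""
    for _ in range(warmup):
        fn(*args)  # the first call may include one-off costs such as JIT compilation
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best

def busy(n):
    """Stand-in workload (assumption: replaces the real kernel launch)."""
    return sum(k * k for k in range(n))

elapsed = time_call(busy, 100_000)
print("Elapsed time: {}".format(elapsed))
```

For the GPU case, fn would be a small wrapper that launches the kernel and then waits for it to finish, for example by calling cuda.synchronize(), so that the timer does not stop before the work is done.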