python / numba.jit / numba.cuda.jit速度比較

Python、numba.jit、numba.njit、numba.cuda.jitの速度比較をしてみた。GPUパワーをまざまざと見せつけられる結果となった。

import numpy as np
from pylab import imshow, show
from timeit import default_timer as timer

def mandel(x, y, max_iters):
    """Return the escape iteration count for the point ``x + y*1j``.

    Starting from ``z = 0``, repeatedly applies ``z = z*z + c`` with
    ``c = complex(x, y)``.

    Args:
        x: Real part of the point being tested.
        y: Imaginary part of the point being tested.
        max_iters: Maximum number of iterations to perform.

    Returns:
        The zero-based index of the first iteration after which the
        squared magnitude of ``z`` is at least 4, or ``max_iters`` when
        that never happens within ``max_iters`` iterations.
    """
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit * orbit + point
        re_part = orbit.real
        im_part = orbit.imag
        norm_sq = re_part * re_part + im_part * im_part
        if norm_sq >= 4:
            return step

    return max_iters
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[row, col]`` is mapped to the point
    ``(min_x + col * step_x, min_y + row * step_y)``, where the steps are
    the x/y extents divided by the image width/height respectively.

    Args:
        min_x: Lower bound of the real axis.
        max_x: Upper bound of the real axis.
        min_y: Lower bound of the imaginary axis.
        max_y: Upper bound of the imaginary axis.
        image: 2-D indexable buffer exposing ``shape``; written in place.
        iters: Iteration limit forwarded to ``mandel``.
    """
    rows = image.shape[0]
    cols = image.shape[1]

    step_x = (max_x - min_x) / cols
    step_y = (max_y - min_y) / rows

    # Walk the image column by column, as the original did.
    for col in range(cols):
        re_part = min_x + col * step_x
        for row in range(rows):
            im_part = min_y + row * step_y
            image[row, col] = mandel(re_part, im_part, iters)
# Benchmark the pure-Python implementation on a 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# Time only the create_fractal call, using timeit's default_timer.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 3.939450 s
# Re-render into the same `image` buffer using a narrower window
# (x in [-2.0, -1.7], y in [-0.1, 0.1]) and display it; this run is not timed.
create_fractal(-2.0, -1.7, -0.1, 0.1, image, 20) 
imshow(image)
show()

マンデルブロ集合描画(numba.jit版)

from numba import jit

@jit
def mandel(x, y, max_iters):
    """Return the escape iteration count for the point ``x + y*1j``.

    Iterates ``z = z*z + c`` from ``z = 0`` with ``c = complex(x, y)``
    and returns the zero-based index of the first iteration after which
    the squared magnitude of ``z`` is at least 4. Returns ``max_iters``
    when that does not happen within ``max_iters`` iterations.

    Decorated with Numba's ``@jit`` using default options.
    """
    point = complex(x, y)
    orbit = 0.0j
    for step in range(max_iters):
        orbit = orbit * orbit + point
        norm_sq = orbit.real * orbit.real + orbit.imag * orbit.imag
        if norm_sq >= 4:
            return step

    return max_iters

@jit
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
    """Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[row, col]`` corresponds to the point
    ``(min_x + col * step_x, min_y + row * step_y)``; the steps are the
    x/y extents divided by the image width/height.

    Decorated with Numba's ``@jit`` using default options.
    """
    rows = image.shape[0]
    cols = image.shape[1]

    step_x = (max_x - min_x) / cols
    step_y = (max_y - min_y) / rows

    # Column-major traversal, same order as the original loops.
    for col in range(cols):
        re_part = min_x + col * step_x
        for row in range(rows):
            im_part = min_y + row * step_y
            image[row, col] = mandel(re_part, im_part, iters)
# Benchmark the create_fractal defined immediately above on a fresh
# 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# NOTE(review): this is the first call of the freshly decorated function, so
# the measured time presumably includes Numba's JIT compilation; confirm, and
# consider a warm-up call if only steady-state speed is of interest.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 0.190077 s
3.939450/0.190077
20.725548067362173

numba.jit版はpython版の約21倍高速という結果だった。

# parallel=True is intentionally not requested here: the only loop in this
# function carries `z` from one iteration to the next and returns early, so
# there is nothing for Numba to parallelize. Per the Numba documentation,
# asking for parallel=True in that situation only yields a performance warning.
@jit(nopython=True, fastmath=True)
def mandel(x, y, max_iters):
  """
    Return the escape iteration count for the point ``x + y*1j``.

    Iterates ``z = z*z + c`` from ``z = 0`` with ``c = complex(x, y)`` and
    returns the zero-based index of the first iteration after which the
    squared magnitude of ``z`` is at least 4, or ``max_iters`` if that does
    not happen within ``max_iters`` iterations.
  """
  c = complex(x, y)
  z = 0.0j
  for i in range(max_iters):
    z = z*z + c
    if (z.real*z.real + z.imag*z.imag) >= 4:
      return i

  return max_iters

# prange is needed for parallel=True to have any effect on the loop below.
from numba import prange

@jit(nopython=True, parallel=True, fastmath=True)
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
  """
    Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[y, x]`` corresponds to the point
    ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``.

    The outer (column) loop uses ``prange`` so that the ``parallel=True``
    option in the decorator applies to it; with a plain ``range`` the
    option had nothing to act on.
  """
  height = image.shape[0]
  width = image.shape[1]

  pixel_size_x = (max_x - min_x) / width
  pixel_size_y = (max_y - min_y) / height

  # Each outer iteration writes only column `x` of `image`, so iterations
  # are independent of each other.
  for x in prange(width):
    real = min_x + x * pixel_size_x
    for y in range(height):
      imag = min_y + y * pixel_size_y
      color = mandel(real, imag, iters)
      image[y, x] = color
# Benchmark the create_fractal defined immediately above on a fresh
# 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# NOTE(review): first call of the freshly decorated function; the measured
# time presumably includes JIT compilation; confirm.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 0.181287 s

マンデルブロ集合描画(numba.njit版)

from numba import njit, prange

@njit(nogil=True, fastmath=True)
def mandel(x, y, max_iters):
  """
    Return the escape iteration count for the point ``x + y*1j``.

    Iterates ``z = z*z + c`` from ``z = 0`` with ``c = complex(x, y)`` and
    returns the zero-based index of the first iteration after which the
    squared magnitude of ``z`` is at least 4, or ``max_iters`` if that does
    not happen within ``max_iters`` iterations.
  """
  c = complex(x, y)
  z = 0.0j
  # Plain range, not prange: every iteration needs the previous value of z
  # and the loop exits early, so the iterations cannot run independently.
  for i in range(max_iters):
    z = z*z + c
    if (z.real*z.real + z.imag*z.imag) >= 4:
      return i

  return max_iters

@njit(nogil=True, fastmath=True)
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
  """
    Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[y, x]`` corresponds to the point
    ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``.
  """
  height = image.shape[0]
  width = image.shape[1]

  pixel_size_x = (max_x - min_x) / width
  pixel_size_y = (max_y - min_y) / height

  # NOTE(review): the decorator does not pass parallel=True; per the Numba
  # documentation prange then behaves like range, so this loop presumably
  # runs serially. Add parallel=True to the decorator to parallelize it.
  for x in prange(width):
    real = min_x + x * pixel_size_x
    # Inner loop is a plain range: only the outer loop is the candidate for
    # parallel execution, so a nested prange adds nothing.
    for y in range(height):
      imag = min_y + y * pixel_size_y
      color = mandel(real, imag, iters)
      image[y, x] = color
# Benchmark the create_fractal defined immediately above on a fresh
# 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# NOTE(review): first call of the freshly decorated function; the measured
# time presumably includes JIT compilation; confirm.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 0.157178 s

numba.njit版はnumba.jit版よりチョロっとだけ速かった。

from numba import njit, prange

@jit(nogil=True, fastmath=True, nopython=True)
def mandel(x, y, max_iters):
  """
    Return the escape iteration count for the point ``x + y*1j``.

    Iterates ``z = z*z + c`` from ``z = 0`` with ``c = complex(x, y)`` and
    returns the zero-based index of the first iteration after which the
    squared magnitude of ``z`` is at least 4, or ``max_iters`` if that does
    not happen within ``max_iters`` iterations.
  """
  c = complex(x, y)
  z = 0.0j
  # Plain range, not prange: every iteration needs the previous value of z
  # and the loop exits early, so the iterations cannot run independently.
  for i in range(max_iters):
    z = z*z + c
    if (z.real*z.real + z.imag*z.imag) >= 4:
      return i

  return max_iters

@jit(nogil=True, fastmath=True, nopython=True)
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
  """
    Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[y, x]`` corresponds to the point
    ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``.
  """
  height = image.shape[0]
  width = image.shape[1]

  pixel_size_x = (max_x - min_x) / width
  pixel_size_y = (max_y - min_y) / height

  # NOTE(review): the decorator does not pass parallel=True; per the Numba
  # documentation prange then behaves like range, so this loop presumably
  # runs serially. Add parallel=True to the decorator to parallelize it.
  for x in prange(width):
    real = min_x + x * pixel_size_x
    # Inner loop is a plain range: only the outer loop is the candidate for
    # parallel execution, so a nested prange adds nothing.
    for y in range(height):
      imag = min_y + y * pixel_size_y
      color = mandel(real, imag, iters)
      image[y, x] = color
# Benchmark the create_fractal defined immediately above on a fresh
# 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# NOTE(review): first call of the freshly decorated function; the measured
# time presumably includes JIT compilation; confirm.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 0.158555 s

icc_rtをインストールすると高速になるというので入れてみた。

!conda install -y -c numba icc_rt
Solving environment: done

## Package Plan ##

  environment location: /root/.pyenv/versions/miniconda3-4.3.30/envs/caffe2

  added / updated specs: 
    - icc_rt


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    icc_rt-2018.0.2            |                0         9.8 MB  numba
    certifi-2018.4.16          |           py36_0         142 KB
    ------------------------------------------------------------
                                           Total:         9.9 MB

The following NEW packages will be INSTALLED:

    icc_rt:  2018.0.2-0       numba      

The following packages will be UPDATED:

    certifi: 2018.4.16-py36_0 conda-forge --> 2018.4.16-py36_0 
    openssl: 1.0.2o-0         conda-forge --> 1.0.2o-h20670df_0


Downloading and Extracting Packages
icc_rt-2018.0.2      |  9.8 MB | ####################################### | 100% 
certifi-2018.4.16    |  142 KB | ####################################### | 100% 
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
from numba import njit

# parallel=True is intentionally not requested here: the only loop in this
# function carries `z` from one iteration to the next and returns early, so
# there is nothing for Numba to parallelize. Per the Numba documentation,
# asking for parallel=True in that situation only yields a performance warning.
@jit(nogil=True, fastmath=True, nopython=True)
def mandel(x, y, max_iters):
  """
    Return the escape iteration count for the point ``x + y*1j``.

    Iterates ``z = z*z + c`` from ``z = 0`` with ``c = complex(x, y)`` and
    returns the zero-based index of the first iteration after which the
    squared magnitude of ``z`` is at least 4, or ``max_iters`` if that does
    not happen within ``max_iters`` iterations.
  """
  c = complex(x, y)
  z = 0.0j
  for i in range(max_iters):
    z = z*z + c
    if (z.real*z.real + z.imag*z.imag) >= 4:
      return i

  return max_iters

# prange is needed for parallel=True to have any effect on the loop below.
from numba import prange

@jit(nogil=True, fastmath=True, nopython=True, parallel=True)
def create_fractal(min_x, max_x, min_y, max_y, image, iters):
  """
    Fill ``image`` in place with ``mandel`` iteration counts.

    Pixel ``image[y, x]`` corresponds to the point
    ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``.

    The outer (column) loop uses ``prange`` so that the ``parallel=True``
    option in the decorator applies to it; with a plain ``range`` the
    option had nothing to act on.
  """
  height = image.shape[0]
  width = image.shape[1]

  pixel_size_x = (max_x - min_x) / width
  pixel_size_y = (max_y - min_y) / height

  # Each outer iteration writes only column `x` of `image`, so iterations
  # are independent of each other.
  for x in prange(width):
    real = min_x + x * pixel_size_x
    for y in range(height):
      imag = min_y + y * pixel_size_y
      color = mandel(real, imag, iters)
      image[y, x] = color
# Benchmark the create_fractal defined immediately above on a fresh
# 1024 x 1536 uint8 image.
image = np.zeros((1024, 1536), dtype = np.uint8)
# NOTE(review): first call of the freshly decorated function; the measured
# time presumably includes JIT compilation; confirm.
start = timer()
create_fractal(-2.0, 1.0, -1.0, 1.0, image, 20) 
dt = timer() - start

# Report the elapsed time in seconds, then display the rendered image.
print ("Mandelbrot created in %f s" % dt)
imshow(image)
show()
Mandelbrot created in 0.165475 s

色々試してみたが、高速になるどころか速度が落ちた。

マンデルブロ集合描画(cuda.jit版)

from numba import cuda
from numba import *

# Compile `mandel` as a CUDA device function: (f8, f8, uint32) -> uint32.
# The signature is passed positionally; the `restype=`/`argtypes=` keywords
# are a legacy spelling that current Numba releases reject.
# `mandel` was last defined with a CPU @jit decorator, so hand cuda.jit the
# underlying Python function (`py_func`) when that attribute is present.
mandel_gpu = cuda.jit(uint32(f8, f8, uint32), device=True)(
    getattr(mandel, "py_func", mandel)
)

# The signature is passed positionally as a tuple of argument types; the
# `argtypes=` keyword is a legacy spelling that current Numba releases reject.
@cuda.jit((f8, f8, f8, f8, uint8[:, :], uint32))
def mandel_kernel(min_x, max_x, min_y, max_y, image, iters):
  """
    CUDA kernel that fills ``image`` in place with ``mandel_gpu`` results.

    Pixel ``image[y, x]`` corresponds to the point
    ``(min_x + x * pixel_size_x, min_y + y * pixel_size_y)``.

    Each thread starts at its own 2-D grid position and then advances by
    the total number of threads along each axis, so the whole image is
    covered even when it is larger than the launched grid.
  """
  height = image.shape[0]
  width = image.shape[1]

  pixel_size_x = (max_x - min_x) / width
  pixel_size_y = (max_y - min_y) / height

  # Absolute (x, y) position of this thread within the launched grid.
  startX, startY = cuda.grid(2)
  # Total number of threads along each axis; used as the loop stride.
  gridX = cuda.gridDim.x * cuda.blockDim.x
  gridY = cuda.gridDim.y * cuda.blockDim.y

  for x in range(startX, width, gridX):
    real = min_x + x * pixel_size_x
    for y in range(startY, height, gridY):
      imag = min_y + y * pixel_size_y
      image[y, x] = mandel_gpu(real, imag, iters)
# Host-side output buffer: 1024 x 1536 uint8 image.
gimage = np.zeros((1024, 1536), dtype = np.uint8)
# Launch configuration: a 32 x 16 grid of 32 x 8 thread blocks. That is fewer
# threads than pixels; the kernel's strided loops cover the remainder.
blockdim = (32, 8)
griddim = (32,16)

# The timed region covers the host-to-device copy, the kernel launch and the
# device-to-host copy.
start = timer()
d_image = cuda.to_device(gimage)
mandel_kernel[griddim, blockdim](-2.0, 1.0, -1.0, 1.0, d_image, 20)
# Copy the result back into `gimage`. copy_to_host is the documented
# device-to-host transfer; the legacy to_host() is not available in current
# Numba releases.
d_image.copy_to_host(gimage)
dt = timer() - start

print ("Mandelbrot created on GPU in %f s" % dt)

imshow(gimage)
show()
Mandelbrot created on GPU in 0.006725 s
3.939450/0.006725
585.7918215613382
0.157178/0.006725
23.372193308550187

numba.cuda.jit版はpython版の586倍高速、numba.njit版の23倍高速という結果に終わった。圧倒的なGPUパワーを見せつけられた格好だ。

参考サイト: A NumbaPro Mandelbrot Example