numba.jit / cython / C++の速度比較

numba.jit、Cython、pure Pythonの速度比較をしてみた（後半ではC++とも比較する）。コードはここから拝借した。

スポンサーリンク

numba.jit, cython, pythonの速度比較

import numpy as np

def trivial_sum2d(arr):
    """Return the sum of all elements of a 2-D array using plain Python loops.

    This is the deliberately naive pure-Python baseline of the benchmark:
    every element is read with ``arr[row, col]`` in row-major order and added
    to a running total that starts as the Python float ``0.0``.
    """
    rows, cols = arr.shape
    total = 0.0
    for row in range(rows):
        for col in range(cols):
            total = total + arr[row, col]
    return total
# Shared benchmark input for every implementation below: np.arange(100000000)
# reshaped to 1000 x 100000 (row-major) and cast to float32, so values above
# 2**24 are rounded to the nearest representable float32.
x = np.arange(100000000).reshape(1000, 100000).astype(np.float32)
# -n 10 -r 1: a single timing run of 10 calls.
%timeit -n 10 -r 1 trivial_sum2d(x)
18.3 s ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
from numba import jit
import numpy as np

@jit(nopython=True)
def numba_sum2d(arr):
    """Return the sum of all elements of a 2-D array, compiled by numba.

    The body is identical to the pure-Python ``trivial_sum2d`` loop; only the
    decorator differs. ``nopython=True`` makes numba raise if it cannot compile
    the loop to machine code. Bare ``@jit`` in numba < 0.59 would instead
    silently fall back to the much slower object mode and skew the benchmark.
    The running total starts as ``0.0`` (double precision).
    """
    n, m = arr.shape
    ret = 0.0
    for i in range(n):
        for j in range(m):
            ret += arr[i, j]
    return ret
# NOTE(review): numba_sum2d is compiled lazily on its first call; if it was not
# called before this cell, compilation time is included in the reported mean.
%timeit -n 10 -r 1 numba_sum2d(x)
103 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
%load_ext cython
%%cython -a
import numpy as np
import cython
cimport numpy as np

@cython.boundscheck(False)  # every index below lies within arr.shape
@cython.wraparound(False)   # indices are never negative
cdef double _sum2d(np.ndarray[np.float32_t, ndim=2] arr) except? -1:
    # Sum every element of a 2-D float32 array, accumulating in a C double.
    #
    # `except? -1` lets errors raised inside this function (e.g. a buffer
    # dtype/ndim mismatch while acquiring `arr`) propagate to the caller.
    # Without it, older Cython versions print them as "ignored" and return a
    # meaningless value.
    #
    # Py_ssize_t matches NumPy's index width; a C int would truncate
    # dimensions above INT_MAX. Typing i and j explicitly guarantees plain C
    # loops with direct buffer indexing.
    cdef Py_ssize_t n = arr.shape[0]
    cdef Py_ssize_t m = arr.shape[1]
    cdef Py_ssize_t i, j
    cdef double ret = 0.0
    for i in range(n):
        for j in range(m):
            ret += arr[i, j]
    return ret

def cython_sum2d(np.ndarray arr not None):
    """Python-visible wrapper around the C-level ``_sum2d``.

    ``not None`` makes ``cython_sum2d(None)`` raise TypeError here. Otherwise
    None would be passed into ``_sum2d``, whose typed ``arr.shape`` access does
    not check for None. The float32 / 2-D requirements are still enforced by
    ``_sum2d``'s buffer type.
    """
    return _sum2d(arr)
%timeit -n 10 -r 3 cython_sum2d(x)
102 ms ± 701 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
18.3/.103
177.66990291262138

相変わらずCythonが最速（102 ms）だが、numba.jit（103 ms）もほぼ互角で、pure Pythonの約178倍高速という結果になった。しかも、numba.jitはCythonと違い、@jitデコレータを加えるだけで高速化できるのが妙味だ。

C++ versionとの速度比較

%%file trivial.cpp
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

/// Dense row-major matrix of floats with n rows and m columns.
///
/// Storage is owned by a std::vector, so copying, moving and destroying a
/// Matrix are all safe with the compiler-generated members (Rule of Zero).
/// A raw new[]/delete[] buffer plus implicit copy operations would
/// double-delete on any copy.
class Matrix {
    std::vector<float> data;  // n*m elements, row-major
public:
    size_t n, m;  // rows, columns
    /// Creates an r x c matrix with every element zero-initialized.
    Matrix(size_t r, size_t c): data(r * c), n(r), m(c) {}
    /// Mutable reference to element (x, y) = row x, column y. No bounds check.
    float& operator() (size_t x, size_t y) { return data[x*m+y]; }
    /// Value of element (x, y) = row x, column y. No bounds check.
    float operator() (size_t x, size_t y) const { return data[x*m+y]; }
};

/// Returns the sum of all elements of `a`, visited in row-major order.
///
/// The running total is kept in double precision. A float accumulator stops
/// absorbing addends once it grows large: with the 1000 x 100000 benchmark
/// matrix it stalls around 2^51 (about 2.2518e15) instead of reaching about
/// 5e15. Double also matches the accumulator used by the Python, numba and
/// Cython versions. The float return type is kept for existing callers.
float sum2d(const Matrix &a) {
    double ret = 0.0;
    for (size_t i = 0; i < a.n; ++i)
        for (size_t j = 0; j < a.m; ++j)
            ret += a(i, j);
    return static_cast<float>(ret);
}

/// Fills `a` with 0, 1, 2, ... in row-major order, each value converted to
/// float. This mirrors np.arange(n*m).reshape(n, m).astype(np.float32).
void fill(Matrix &a) {
    // Each value is derived from the element's size_t linear index rather than
    // an `int` running counter. Such a counter would overflow (undefined
    // behaviour) for matrices with more than INT_MAX elements.
    for (size_t i = 0; i < a.n; ++i)
        for (size_t j = 0; j < a.m; ++j)
            a(i, j) = static_cast<float>(i * a.m + j);
}

/// Benchmark driver. Builds an n x m matrix filled with 0, 1, 2, ..., times T
/// calls of sum2d, prints the accumulated sum to stderr, and prints the mean
/// time per call in milliseconds to stdout.
int main() {
    const int n = 1000, m = 100000, T = 10;
    Matrix x(n, m);
    fill(x);
    // steady_clock is monotonic. system_clock can jump (NTP or manual
    // adjustment) and is not intended for measuring intervals.
    const auto st = std::chrono::steady_clock::now();
    // The results are accumulated and printed below so the optimizer cannot
    // drop the sum2d calls. Double keeps the total exact enough to be meaningful.
    double s = 0.0;
    for (int i = 0; i < T; ++i) {
        s += sum2d(x);
    }
    const auto ed = std::chrono::steady_clock::now();
    const std::chrono::duration<double, std::milli> diff = ed - st;
    std::cerr << s << '\n';
    std::cout << T << " loops. average " << diff.count() / T << "ms" << std::endl;
}
Writing trivial.cpp
!g++ -std=c++11 trivial.cpp -o a
!./a
2.2518e+16
10 loops. average 310.743ms
!g++ -O2 -std=c++11 trivial.cpp -o a
!./a
2.2518e+16
10 loops. average 102.593ms

Cython version（102 ms）は、最適化なし（-O0）でコンパイルしたC++ version（310.7 ms）より約3倍高速で、-O2を付けたC++ version（102.6 ms）ともほぼ同等という驚きの結果となった。