Inspired by the article "Explicitly writing a loop in NumPy is extremely slow".
That article points out that writing an explicit for loop is extremely slow. For example:
import numpy as np

def matmul1(a, b):
    lenI = a.shape[0]  # rows of a
    lenJ = b.shape[1]  # columns of b
    lenK = a.shape[1]  # shared dimension (= b.shape[0])
    c = np.zeros((lenI, lenJ))
    for i in range(lenI):
        for j in range(lenJ):
            for k in range(lenK):
                c[i, j] += a[i, k] * b[k, j]
    return c
Code like this is far slower than np.dot:
%timeit matmul1(a, b)
1 loops, best of 3: 12.9 s per loop
%timeit np.dot(a, b)
10 loops, best of 3: 20.7 ms per loop
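The excerpt above does not show how a and b were created. A minimal sketch of a plausible setup, together with a correctness check against np.dot (the 300x300 size and random values are my assumption, not from the original article):

import numpy as np

# Hypothetical test matrices; the size is an assumption.
a = np.random.rand(300, 300)
b = np.random.rand(300, 300)

# The loop version and np.dot should agree up to floating-point rounding.
assert np.allclose(matmul1(a, b), np.dot(a, b))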
Part of the reason it is slow is that this was measured on my laptop, and NumPy there is not linked against ATLAS/MKL.
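If you want to check which BLAS (if any) your NumPy build is linked against, NumPy can print its build configuration. A quick sketch:

import numpy as np

# Prints the build configuration, including any BLAS/LAPACK libraries
# (ATLAS, MKL, OpenBLAS, ...) that NumPy was linked against.
np.show_config()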
Now let's use Numba.
import numpy as np
import numba

@numba.jit  # only this decorator line is added
def matmul1_jit(a, b):
    lenI = a.shape[0]  # rows of a
    lenJ = b.shape[1]  # columns of b
    lenK = a.shape[1]  # shared dimension (= b.shape[0])
    c = np.zeros((lenI, lenJ))
    for i in range(lenI):
        for j in range(lenJ):
            for k in range(lenK):
                c[i, j] += a[i, k] * b[k, j]
    return c
Numba JIT-compiles the Python code with LLVM, so it can run very fast. The first call includes compilation time, so timing subsequent calls gives:
%timeit matmul1_jit(a, b)
10 loops, best of 3: 24.4 ms per loop
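One way to keep compilation time out of the measurement is to call the jitted function once before timing, or to compile eagerly by giving @numba.jit an explicit signature. A rough sketch under the assumption of float64 inputs (matmul1_eager is a hypothetical name, not from the original):

import numpy as np
import numba

# Eager compilation: the signature makes Numba compile at decoration time,
# so the first timed call no longer includes compilation.
@numba.jit("float64[:,:](float64[:,:], float64[:,:])", nopython=True)
def matmul1_eager(a, b):
    lenI, lenJ, lenK = a.shape[0], b.shape[1], a.shape[1]
    c = np.zeros((lenI, lenJ))
    for i in range(lenI):
        for j in range(lenJ):
            for k in range(lenK):
                c[i, j] += a[i, k] * b[k, j]
    return c

# Alternatively, a single warm-up call compiles the lazily-jitted version:
# matmul1_jit(a, b)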
Just adding a single decorator line made it roughly as fast as np.dot (about 20% slower).
I've put the whole .ipynb in a Gist. I wish I could embed nbviewer in Qiita.