Jupyter Snippet CB2nd 02_numba
5.2. Accelerating pure Python code with Numba and just-in-time compilation
import numpy as np
import matplotlib.pyplot as plt
# IPython magic: render Matplotlib figures inline in the notebook.
%matplotlib inline
# Shared benchmark parameters: the image is size x size pixels, and each
# pixel runs at most `iterations` steps of the iteration loop.
size = 400
iterations = 100
def mandelbrot_python(size, iterations):
    """Return Mandelbrot escape-iteration counts, computed in pure Python.

    Row ``i`` of the ``size`` x ``size`` grid has imaginary part
    ``1.5 - (3 / size) * i`` and column ``j`` has real part
    ``-2 + (3 / size) * j``.  Each pixel iterates ``z <- z * z + c`` from
    ``z = 0`` for at most ``iterations`` steps, stopping as soon as
    ``|z| > 10``.

    Parameters
    ----------
    size : int
        Side length of the square output grid, in pixels.
    iterations : int
        Maximum number of iterations per pixel.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(size, size)``; entry ``[i, j]`` is the last
        iteration index ``n`` at which ``|z| <= 10`` still held (0 when
        ``iterations`` is 0).
    """
    counts = np.zeros((size, size))
    for row in range(size):
        # The imaginary part depends only on the row, so compute it once.
        imag = 1.5 - 3. / size * row
        for col in range(size):
            c = -2 + 3. / size * col + 1j * imag
            z = 0
            for n in range(iterations):
                # Guard clause: leave as soon as the orbit escapes the disk.
                if not (np.abs(z) <= 10):
                    break
                z = z * z + c
                counts[row, col] = n
    return counts
# Compute the image with the pure-Python implementation and display it.
m = mandelbrot_python(size, iterations)
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# Show the logarithm of the iteration counts using the "hot" colormap.
ax.imshow(np.log(m), cmap=plt.cm.hot)
ax.set_axis_off()
# Benchmark the pure-Python version.
%timeit mandelbrot_python(size, iterations)
5.45 s ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
from numba import jit
# nopython=True makes Numba raise if it cannot compile the function natively,
# instead of (on older Numba releases) silently falling back to slow object
# mode, which would make the timing comparison meaningless.
@jit(nopython=True)
def mandelbrot_numba(size, iterations):
    """Numba-compiled twin of ``mandelbrot_python``.

    Runs the same escape-time loop as the pure-Python version: row ``i`` has
    imaginary part ``1.5 - (3 / size) * i``, column ``j`` has real part
    ``-2 + (3 / size) * j``, and each pixel iterates ``z <- z * z + c`` from
    zero for at most ``iterations`` steps, stopping once ``|z| > 10``.

    Parameters
    ----------
    size : int
        Side length of the square output grid, in pixels.
    iterations : int
        Maximum number of iterations per pixel.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(size, size)``; entry ``[i, j]`` is the last
        iteration index ``n`` at which ``|z| <= 10`` still held (0 when
        ``iterations`` is 0).
    """
    m = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            c = (-2 + 3. / size * j +
                 1j * (1.5 - 3. / size * i))
            # Start from a complex zero so z keeps a single complex type for
            # the whole compiled loop (numerically identical to starting at 0).
            z = 0j
            for n in range(iterations):
                if np.abs(z) <= 10:
                    z = z * z + c
                    m[i, j] = n
                else:
                    break
    return m
# Warm-up call: a @jit function without an explicit signature is compiled on
# its first call, so calling it once here keeps compilation time out of the
# %timeit measurement below.
mandelbrot_numba(size, iterations)
%timeit mandelbrot_numba(size, iterations)
34.5 ms ± 59.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
def initialize(size):
    """Build the inputs for ``mandelbrot_numpy``.

    The complex grid spans real parts ``[-2, 1]`` (columns) and imaginary
    parts ``[-1.5, 1.5]`` (rows), both endpoints included.

    Parameters
    ----------
    size : int
        Number of grid points along each axis.

    Returns
    -------
    tuple of numpy.ndarray
        ``(c, z, m)``, each of shape ``(size, size)``:

        - ``c`` is the complex grid;
        - ``z`` is an independent copy of ``c`` (the starting iterate);
        - ``m`` is a float array of zeros for the iteration counts.

    NOTE(review): this grid differs from the one used in ``mandelbrot_python``.
    Here both endpoints are included, row 0 has imaginary part -1.5 rather
    than +1.5, and ``z`` starts at ``c`` rather than 0. The resulting counts
    are therefore not expected to be pixel-identical.
    """
    real_axis = np.linspace(-2, 1, size)
    imag_axis = np.linspace(-1.5, 1.5, size)
    re_grid, im_grid = np.meshgrid(real_axis, imag_axis)
    grid = re_grid + 1j * im_grid
    return grid, grid.copy(), np.zeros((size, size))
def mandelbrot_numpy(c, z, m, iterations):
    """Run the vectorized Mandelbrot iteration, updating ``z`` and ``m`` in place.

    At each step ``n``, every point that still satisfies ``|z| <= 10`` is
    advanced by ``z <- z ** 2 + c``, and its entry in ``m`` is set to ``n``.
    Points that have escaped are left untouched.

    Parameters
    ----------
    c : numpy.ndarray
        Complex parameter for each point.
    z : numpy.ndarray
        Current iterate for each point; overwritten in place.
    m : numpy.ndarray
        Iteration-count array; overwritten in place.
    iterations : int
        Number of vectorized steps to run.

    ``c``, ``z`` and ``m`` are expected to share one shape, because a boolean
    mask built from ``z`` is used to index all three. Returns ``None``.
    """
    for n in range(iterations):
        inside = np.abs(z) <= 10
        z_inside = z[inside]
        z[inside] = z_inside ** 2 + c[inside]
        m[inside] = n
%%timeit -n1 -r10 c, z, m = initialize(size)
# The setup statement on the magic line rebuilds c, z and m before each run,
# because mandelbrot_numpy overwrites z and m in place. -n1 -r10 means one
# call per run, ten runs, so every timed call starts from fresh arrays.
mandelbrot_numpy(c, z, m, iterations)
174 ms ± 2.91 ms per loop (mean ± std. dev. of 10 runs, 1 loop each)