Jupyter Snippet CB2nd 02_numba

Jupyter Snippet CB2nd 02_numba

5.2. Accelerating pure Python code with Numba and just-in-time compilation

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Side length of the square grid; passed as `size` to the mandelbrot
# functions and to initialize() in the cells below.
size = 400
# Upper bound on the number of iterations performed for each grid point.
iterations = 100
def mandelbrot_python(size, iterations):
    """Compute Mandelbrot iteration counts with plain Python loops.

    Grid point (row, col) is mapped to the complex number
    ``-2 + 3/size * col + 1j * (1.5 - 3/size * row)``. Starting from 0,
    ``z -> z*z + c`` is applied while ``|z| <= 10``, for at most
    ``iterations`` steps.

    Args:
        size: Number of rows and columns of the square grid.
        iterations: Maximum number of steps per grid point.

    Returns:
        A ``(size, size)`` array whose entry is the index of the last
        step that was applied at that point (0 if none was applied).
    """
    counts = np.zeros((size, size))
    for row in range(size):
        # The imaginary part depends only on the row.
        imag_part = 1j * (1.5 - 3. / size * row)
        for col in range(size):
            point = -2 + 3. / size * col + imag_part
            value = 0
            for step in range(iterations):
                # Stop as soon as the orbit has left the disc of radius 10.
                if np.abs(value) > 10:
                    break
                value = value * value + point
                counts[row, col] = step
    return counts
# Compute the iteration-count grid with the pure-Python implementation.
m = mandelbrot_python(size, iterations)
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
# Show the logarithm of the counts using the "hot" colormap.
ax.imshow(np.log(m), cmap=plt.cm.hot)
# Hide the axes around the image.
ax.set_axis_off()

png

%timeit mandelbrot_python(size, iterations)
5.45 s ± 18.6 ms per loop (mean ± std. dev. of 7 runs,
    1 loop each)
from numba import jit
# nopython=True makes Numba raise an error instead of silently falling back
# to (slow) object mode when the function cannot be fully compiled.
@jit(nopython=True)
def mandelbrot_numba(size, iterations):
    """Compute Mandelbrot iteration counts, JIT-compiled by Numba.

    Same algorithm as the pure-Python version: grid point (i, j) is mapped
    to ``c = -2 + 3/size * j + 1j * (1.5 - 3/size * i)`` and
    ``z -> z*z + c`` is applied from ``z = 0`` while ``|z| <= 10``, for at
    most ``iterations`` steps.

    Args:
        size: Number of rows and columns of the square grid.
        iterations: Maximum number of steps per grid point.

    Returns:
        A ``(size, size)`` array whose entry is the index of the last
        step that was applied at that point (0 if none was applied).
    """
    m = np.zeros((size, size))
    for i in range(size):
        for j in range(size):
            c = (-2 + 3. / size * j +
                 1j * (1.5 - 3. / size * i))
            # Start from a complex zero so z has one type for the whole loop.
            z = 0j
            for n in range(iterations):
                # Stop as soon as the orbit has left the disc of radius 10.
                if np.abs(z) > 10:
                    break
                z = z * z + c
                m[i, j] = n
    return m
# Call the function once before timing it. NOTE(review): presumably this is
# so the first-call JIT compilation is not part of the measurement below —
# confirm against the Numba documentation.
mandelbrot_numba(size, iterations)
%timeit mandelbrot_numba(size, iterations)
34.5 ms ± 59.4 µs per loop (mean ± std. dev. of 7 runs,
    10 loops each)
def initialize(size):
    """Build the input arrays for the vectorised Mandelbrot computation.

    Args:
        size: Number of samples along each axis of the grid.

    Returns:
        A tuple ``(c, z, m)`` where ``c`` is the ``(size, size)`` complex
        grid spanning [-2, 1] on the real axis and [-1.5, 1.5] on the
        imaginary axis, ``z`` is an independent copy of ``c``, and ``m``
        is a ``(size, size)`` array of zeros.
    """
    real_axis = np.linspace(-2, 1, size)
    imag_axis = np.linspace(-1.5, 1.5, size)
    real_part, imag_part = np.meshgrid(real_axis, imag_axis)
    plane = real_part + 1j * imag_part
    # z gets its own copy so that updating it later leaves c untouched.
    return plane, plane.copy(), np.zeros((size, size))
def mandelbrot_numpy(c, z, m, iterations):
    """Run the Mandelbrot iteration on whole arrays, updating them in place.

    At every step, the entries of ``z`` with ``|z| <= 10`` are replaced by
    ``z**2 + c`` and the matching entries of ``m`` are set to the step
    index.

    Args:
        c: Array of complex constants, one per grid point (read only).
        z: Array of current values; modified in place.
        m: Array receiving the step indices; modified in place.
        iterations: Number of steps to perform.

    Returns:
        None. The results are left in ``z`` and ``m``.
    """
    for step in range(iterations):
        # Boolean mask of the points that have not left the disc yet.
        bounded = np.abs(z) <= 10
        current = z[bounded]
        z[bounded] = current ** 2 + c[bounded]
        m[bounded] = step
%%timeit -n1 -r10 c, z, m = initialize(size)
mandelbrot_numpy(c, z, m, iterations)
174 ms ± 2.91 ms per loop (mean ± std. dev. of 10 runs,
    1 loop each)