In [1]:
#########################################################
##
##  NOTE: Evaluate this cell first!
##
##  Click on it, then press Shift+Enter (or Ctrl+Enter)
##
##  Instructions for starting the slides should appear below.
##  If they don't, try evaluating this cell again.
##  
#########################################################

from helpers import init, timing_plot
init()

Optimised Primes¶

Emlyn Corrin

Why?¶

  • Online programming contests (Project Euler etc.)
  • As a mathematical or programming exercise
  • Because it's fun!

What is a prime?¶

A prime number (or a prime) is a natural number greater than 1 that has no positive divisors other than 1 and itself.

— Wikipedia

How would that look in code?¶

In [2]:
# A prime number (or a prime) is a natural number greater than 1
# that has no positive divisors other than 1 and itself.

def is_prime(number):
  if not number > 1:         # If it's not greater than 1
    return False             # It can't be a prime
  for d in range(2, number): # Let's check every possible divisor between 2 and number-1
    if number % d == 0:      # If the remainder when dividing number by d is zero
      return False           # It's not a prime
  return True                # If we get this far, it must be a prime

Let's generate a few¶

In [3]:
# All primes less than 20
[i for i in range(20) if is_prime(i)]
Out[3]:
[2, 3, 5, 7, 11, 13, 17, 19]
In [4]:
# Primes less than 1000 (print it otherwise Jupyter only displays 1 number per line)
print([i for i in range(1000) if is_prime(i)])
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997]

Let's make it a bit more flexible¶

Currently we can only check if a particular number is prime.
Let's turn it into a generator function that returns a sequence of primes.
This will allow us to do more things, like:

  • Generate the first $n$ primes
  • Generate primes up to a certain size
  • Generate primes until some other condition is met
  • Optimise it better later (we might not need to check every number)
In [5]:
from itertools import count

# For now let's just loop over all numbers and call is_prime on each one,
# we'll worry about optimising this later.
def first_try():
  for n in count(): # Loop over all non-negative integers (0, 1, 2, ...)
    if is_prime(n): # Check each one to see if it's prime
      yield n       # If so, yield it (return it and continue)
In [6]:
from itertools import islice

# First 20 primes:
[p for p in islice(first_try(), 20)]
Out[6]:
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71]
In [7]:
from itertools import takewhile

# Primes less than 50
[p for p in takewhile(lambda x: x < 50, first_try())]
Out[7]:
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
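The remaining use case from the list, stopping on some other condition, can be sketched in the same way. The block below repeats is_prime and first_try so that it runs on its own; primes_while_sum_at_most is a hypothetical helper introduced only for this illustration, and the condition (a running total) is an arbitrary choice.

```python
from itertools import count

def is_prime(number):
    if number < 2:                 # 0 and 1 are not prime
        return False
    for d in range(2, number):     # Try every divisor from 2 to number-1
        if number % d == 0:
            return False
    return True

def first_try():
    for n in count(2):             # 2, 3, 4, ...
        if is_prime(n):
            yield n

def primes_while_sum_at_most(limit):
    """Yield primes for as long as their running total stays <= limit."""
    total = 0
    for p in first_try():
        total += p
        if total > limit:          # The condition depends on more than p itself
            return
        yield p

print(list(primes_while_sum_at_most(100)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23]
```

Because the generator is lazy, the caller decides when to stop; the generator never needs to know the stopping rule in advance.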

But how fast is it?¶

In [8]:
timing_plot(first_try)
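The timing_plot helper comes from the notebook's helpers module, which is not shown here. As a rough stand-in for readers without it, the sketch below times how long a generator takes to produce its first n primes. The name time_first_n and the choice of sizes are assumptions made for this illustration, not part of the original helper.

```python
from itertools import count, islice
from time import perf_counter

def first_try():
    """Compact copy of the trial-division generator, so the block runs on its own."""
    for n in count(2):
        if all(n % d for d in range(2, n)):
            yield n

def time_first_n(generator_function, n):
    """Return the seconds taken to pull the first n values from generator_function()."""
    start = perf_counter()
    for _ in islice(generator_function(), n):
        pass                                   # Consume the values, discard them
    return perf_counter() - start

for n in (100, 200, 400):
    print(n, round(time_first_n(first_try, n), 4))
```

Doubling n and watching how the time grows gives a feel for the scaling, even without a plot.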

Can we make it faster?¶

We know 2 is the only even prime, so why not skip even numbers (apart from 2)?¶

In [9]:
from itertools import count

def skip_even():
  def is_prime(number):           # Define a local is_prime that skips even divisors
    for d in range(3, number, 2): # Check odd divisors between 3 and number-1
      if number % d == 0:         # If there's no remainder, it divides,
        return False              # so we haven't got a prime
    return True                   # If we get this far, it must be prime
  yield 2                         # Make sure we start by yielding the only even prime, 2
  for n in count(3, 2):           # Then loop over odd integers from 3 upwards
    if is_prime(n):               # Checking each (odd) number to see if it's prime
      yield n

How much faster is it?¶

In [10]:
timing_plot(skip_even)

Can we reduce the number of checks further?¶

We could also skip multiples of 3¶

Skipping multiples of 3 as well would give us another factor of $\approx 3/2$, or 1.5. It is not a factor of 3, because half of the multiples of 3 (the even ones) are already skipped. It would also complicate the code quite a bit.
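To give a sense of the extra complexity, here is a minimal sketch of what skipping multiples of both 2 and 3 might look like. It relies on the fact that every prime greater than 3 has the form $6k \pm 1$. The function name is made up for this illustration, and this is only one of several ways to write it.

```python
from itertools import count, islice

def skip_multiples_of_2_and_3():
    def is_prime(number):
        # number is never a multiple of 2 or 3 here, so its divisors
        # can only be of the form 6k - 1 (d) or 6k + 1 (d + 2).
        for d in range(5, number, 6):
            if number % d == 0:
                return False
            if d + 2 < number and number % (d + 2) == 0:
                return False
        return True

    yield 2                              # The two primes that are not of the form 6k +/- 1
    yield 3
    for k in count(1):
        for n in (6 * k - 1, 6 * k + 1): # Candidates: 5, 7, 11, 13, 17, 19, ...
            if is_prime(n):
                yield n

print(list(islice(skip_multiples_of_2_and_3(), 10)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```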

Is there a way we can do better than that?

Yes!¶

Factors always come in pairs:
if $n$ has a factor $f$, that means $n = f \times g$,
and therefore $g$ must also be a factor.

Now, either $f$ and $g$ are both the same and equal to $\sqrt n$,
or one of them must be less than $\sqrt n$.

They can't both be greater than $\sqrt n$.

So if $n$ is composite, at least one of its factors (other than 1) must be $\leq \sqrt n$,
and therefore we can stop checking once we reach $\sqrt n$.
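As a small worked example of the pairing argument, the sketch below lists the factor pairs of a number. The helper factor_pairs is hypothetical, and math.isqrt needs Python 3.8 or later.

```python
from math import isqrt

def factor_pairs(n):
    """Return every pair (f, g) with f <= g and f * g == n."""
    return [(f, n // f) for f in range(1, isqrt(n) + 1) if n % f == 0]

print(factor_pairs(36))
# [(1, 36), (2, 18), (3, 12), (4, 9), (6, 6)]
```

In every pair the smaller factor is at most $\sqrt{36} = 6$, so scanning up to 6 is enough to find them all.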

Let's write the code for that:

In [11]:
from itertools import count
from math import sqrt

def to_sqrt():
  def is_prime(number):
    limit = int(sqrt(number))        # The highest number we have to check
    for d in range(3, limit + 1, 2): # check odd divisors from 3 to limit
      if number % d == 0:            # If there's no remainder, it divides,
        return False                 # so we haven't got a prime
    return True                      # If we get this far, it must be prime
  yield 2                            # Start with 2
  for n in count(3, 2):              # Then do the odd numbers from 3 upwards
    if is_prime(n):                  # checking each to see if it's prime
      yield n

How much faster is this?¶

In [12]:
timing_plot(to_sqrt)

Is this the best we can do?¶

We are still checking more numbers than necessary...

For example, once we've tested for divisibility by 3 and 5,
we shouldn't need to test their odd multiples (e.g. 9, 15, 21, 25, 27, ...).

i.e. we only need to check for divisibility by prime numbers.
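To put a number on the saving, the sketch below counts how many divisors each approach would try in the worst case for one candidate. The candidate 10007 and both helper names are chosen purely for illustration.

```python
from math import isqrt

def odd_divisors_up_to_sqrt(n):
    """Divisors tried when checking every odd number up to sqrt(n)."""
    return list(range(3, isqrt(n) + 1, 2))

def odd_prime_divisors_up_to_sqrt(n):
    """Divisors tried when checking only the odd primes up to sqrt(n)."""
    return [d for d in odd_divisors_up_to_sqrt(n)
            if all(d % q for q in range(3, isqrt(d) + 1, 2))]

n = 10007
print(len(odd_divisors_up_to_sqrt(n)), len(odd_prime_divisors_up_to_sqrt(n)))
# 49 24
```

The gap widens as the candidates grow, because primes become sparser among the odd numbers.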

What about storing a list of primes so far, and only checking those?¶

In [13]:
from itertools import count

def check_primes():
  yield 2                       # Initially yield 2, then we only consider odd numbers
  primes = []                   # Keep a list of the odd primes seen so far
  for candidate in count(3, 2): # Let's check all odd numbers starting from 3
    isprime = True              # Start by assuming it is a prime
    for p in primes:            # Then start going through all our known primes
      if p * p > candidate:     # If the next prime is > sqrt(candidate)
        break                   # No need to continue looking at higher primes
      if candidate % p == 0:    # Else, if candidate is divisible by this prime
        isprime = False         # Our candidate was not a prime after all
        break                   # And stop looking at more primes
    if isprime:                 # If our candidate turned out to be a prime number
      yield candidate           # Yield it to the caller
      primes.append(candidate)  # And add it to the end of our list of primes
In [14]:
timing_plot(check_primes)

What next?¶

Trial division is (relatively) slow. Instead of trial-dividing candidate primes, we can generate and eliminate the composite numbers, leaving behind the primes.

The sieve of Eratosthenes¶

  1. start with a grid of numbers, from 2 to max_prime
  2. find first (next) unmarked number, return that as a prime
  3. mark all multiples of it (for a prime $p$, it is enough to start from $p^2$)
  4. go back to step 2.
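The steps above can be traced on a small grid. The sketch below records, for each prime, which numbers it newly crosses off when marking starts from its square. The function sieve_trace is a hypothetical helper written for this illustration.

```python
def sieve_trace(limit):
    """Map each sieving prime to the numbers below limit that it newly crosses off."""
    is_composite = [False] * limit
    trace = {}
    for p in range(2, limit):
        if is_composite[p]:
            continue                          # Already crossed off, so not a prime
        if p * p >= limit:
            break                             # Nothing left for p or any larger prime to mark
        newly_marked = []
        for multiple in range(p * p, limit, p):
            if not is_composite[multiple]:
                is_composite[multiple] = True
                newly_marked.append(multiple)
        trace[p] = newly_marked
    return trace

for prime, marked in sieve_trace(30).items():
    print(prime, marked)
# 2 [4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
# 3 [9, 15, 21, 27]
# 5 [25]
```

Multiples of 3 below 9 and multiples of 5 below 25 were already crossed off by smaller primes, which is why starting from the square loses nothing.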
In [15]:
def simple_sieve(max_prime):
  sieve = [True] * max_prime             # Create the "sieve" (an array of booleans)
  for i in range(2, max_prime):          # Loop over the cells of the sieve from 2
    if sieve[i]:                         # If this cell is True
      yield i                            # It's a prime
      for j in range(2*i, max_prime, i): # So loop over all its multiples
        sieve[j] = False                 # and mark them as non-prime
In [16]:
timing_plot(simple_sieve)
In [17]:
def improved_sieve(max_prime):
    if max_prime > 2:                            # Yield the only even prime (if it is in range)
        yield 2
    sieve = [True] * (max_prime // 2)            # Create sieve of only odd numbers (half the size)
    for i in range(3, max_prime, 2):             # Loop over only odd numbers from 3
        if sieve[i//2]:                          # If this cell is True
            yield i                              # It's a prime
            for j in range(i*i, max_prime, i*2): # Loop over odd multiples starting from its square
                sieve[j//2] = False              # and mark them as non-prime
In [18]:
timing_plot(improved_sieve)

Problems?¶

Memory use¶

  • Use a packed data structure (e.g. a bytearray with bit operations), encoding 8 cells per byte
  • Also skip multiples of 3 (only check numbers of form $6n \pm 1$)
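As a minimal sketch of the first idea, the sieve below stores one bit per number in a bytearray instead of one list element per number. It is a variation on simple_sieve, not the notebook's own code, and it deliberately leaves out the odd-only and $6n \pm 1$ refinements to keep the bit handling visible.

```python
def bit_sieve(max_prime):
    """Yield the primes below max_prime, using one bit of storage per number."""
    if max_prime <= 2:
        return                                       # No primes below 2
    composite = bytearray((max_prime + 7) // 8)      # All bits start at 0 (not marked)
    for i in range(2, max_prime):
        # Byte i >> 3 holds the bit for i, at position i & 7 within that byte
        if not composite[i >> 3] & (1 << (i & 7)):
            yield i                                  # Bit still clear: i is prime
            for j in range(i * i, max_prime, i):
                composite[j >> 3] |= 1 << (j & 7)    # Set the bit for each multiple

print(list(bit_sieve(30)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

The trade-off is extra shifting and masking per access in exchange for roughly one eighth of the memory of a bytearray with one byte per cell.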

Need to allocate storage upfront¶

We often don't know in advance how much to allocate (e.g. for the first 100k primes).

What can we do about it?¶

What about switching things around? For each prime, we store its next multiple above the current candidate. Then we only have to check whether the candidate is among the stored multiples, instead of doing several trial divisions per candidate. For each stored multiple we also keep the prime it came from, so that when we reach the multiple we can add the prime to it to get the next one. A number can be a multiple of more than one prime, so we have to store a list of source primes:

In [19]:
from itertools import count

def unbounded_sieve():
  state = {}
  for candidate in count(2):
    if candidate in state:
      for factor in state[candidate]:
        if candidate + factor in state:
          state[candidate + factor].append(factor)
        else:
          state[candidate + factor] = [factor]
      del state[candidate]
    else:
      yield candidate
      state[2 * candidate] = [candidate]
In [20]:
timing_plot(unbounded_sieve)
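To see how the stored multiples move along, the sketch below does the same bookkeeping as unbounded_sieve but yields a snapshot of the dictionary after every candidate. The name unbounded_sieve_states and the use of pop and setdefault are choices made for this illustration; the logic is intended to match the cell above.

```python
from itertools import count, islice

def unbounded_sieve_states():
    """Yield (candidate, is_prime, copy of state) after processing each candidate."""
    state = {}
    for candidate in count(2):
        if candidate in state:
            # Composite: move each source prime on to its next multiple
            for factor in state.pop(candidate):
                state.setdefault(candidate + factor, []).append(factor)
            is_prime = False
        else:
            # Prime: register its first multiple
            state[2 * candidate] = [candidate]
            is_prime = True
        yield candidate, is_prime, {k: list(v) for k, v in state.items()}

for candidate, is_prime, state in islice(unbounded_sieve_states(), 5):
    print(candidate, "prime" if is_prime else "composite", state)
```

After candidate 4, for example, the entry for 6 holds both 3 and 2, which is the case of a multiple with more than one source prime.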

We can make a few optimisations:

  • Use a defaultdict so we don't have to check whether a number is already present before appending to its list
  • Skip even numbers, and therefore even multiples of primes (so the stored increment is $2p$ rather than $p$)
  • When we find a prime $p$, the first multiple we have to add to the state is $p^2$, because every smaller multiple $p \times q$ with $q < p$ has a factor less than $p$ and is eliminated by that smaller factor
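One detail of the defaultdict approach is worth a quick check: a membership test does not create an entry, while indexing a missing key does. A small sketch using only the standard library:

```python
from collections import defaultdict

state = defaultdict(list)

print(9 in state)      # False: "in" only looks, it does not insert anything
print(len(state))      # 0

state[9].append(6)     # Indexing a missing key creates an empty list first
print(dict(state))     # {9: [6]}
```

This is why a sieve can still use a plain membership test to decide whether a candidate is prime, and only the append needs no existence check.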

In [21]:
from collections import defaultdict
from itertools import count

def unbounded_sieve2():
  yield 2
  state = defaultdict(list)
  for candidate in count(3, 2):
    if candidate in state:
      for inc in state[candidate]:
        state[candidate + inc].append(inc)
      del state[candidate]
    else:
      yield candidate
      state[candidate * candidate] = [2 * candidate]
In [22]:
timing_plot(unbounded_sieve2)
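Since this sieve needs no upper bound, it suits the earlier "first 100k primes" case, where the amount of storage is not known in advance. The sketch below repeats unbounded_sieve2 so that it runs on its own; nth_prime is a hypothetical helper added for this illustration.

```python
from collections import defaultdict
from itertools import count, islice

def unbounded_sieve2():
    yield 2
    state = defaultdict(list)
    for candidate in count(3, 2):
        if candidate in state:
            for inc in state[candidate]:
                state[candidate + inc].append(inc)   # Move on to the next odd multiple
            del state[candidate]
        else:
            yield candidate
            state[candidate * candidate] = [2 * candidate]

def nth_prime(n):
    """Return the n-th prime, counting from n = 1 for the prime 2."""
    return next(islice(unbounded_sieve2(), n - 1, None))

print(nth_prime(1000))
# 7919
print(nth_prime(100_000))
```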

But!¶

If you really need fast primes, don't reinvent the wheel! A properly optimised native C library is still much faster...

In [23]:
from pyprimesieve import primes

def library(n):
  return primes(n)

timing_plot(library)