Saturday, 9 December 2017

Generators have state

Python’s generator expressions (discussed in two previous posts) are very useful for stream programming. But sometimes, you want some state to be preserved between calls. That is where we need full Python generators.

A generator expression is actually a special case of a generator that can be written inline. The generator expression (<expr> for i in <iter>) is shorthand for the generator:
def some_gen():
    for i in <iter>:
        yield <expr>
For example, (n*n for n in count(1)) can be equivalently written as:
from itertools import * 

def squares():
    for n in count(1):
        yield n*n

print_for(squares())
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
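The `print_for` and `print_while` helpers used throughout come from the earlier posts and are not repeated here. For readers running the examples on their own, the following is a guess at what they might look like, inferred only from the output format shown in this post; the signatures and the default of 10 items are assumptions.

```python
from itertools import count, islice, takewhile

# Hypothetical stand-ins for the helpers from the earlier posts;
# the names match the post, but the bodies are assumptions.

def print_for(gen, n=10):
    """Print the first n items of gen, comma separated, then '...'."""
    print(*islice(gen, n), "...", sep=", ")

def print_while(gen, limit):
    """Print items of gen while they are below limit, then '...'."""
    print(*takewhile(lambda x: x < limit, gen), "...", sep=", ")

print_for((n * n for n in count(1)), 5)    # 1, 4, 9, 16, 25, ...
print_while((n * n for n in count(1)), 30) # 1, 4, 9, 16, 25, ...
```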
The yield statement acts somewhat like a return in providing its value. But the crucial difference is what happens on the next call. An ordinary function starts off again from the top; a generator starts again directly after the previous yield. This becomes important if the generator includes some internal state: this state is maintained between calls. That is, we can have memory of previous state on the next call. This is particularly useful for recurrence relations, where the current value is expressed in terms of previous values (kept as remembered state).
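To see the resumption behaviour concretely, here is a minimal sketch that records when each part of a generator body runs. The names `counter_demo` and `log` are purely illustrative.

```python
def counter_demo(log):
    """Yield 1, 2, 3, recording in `log` when each part of the body runs."""
    log.append("start")   # runs once, on the first next() only
    n = 0                 # local state, kept alive between calls
    while n < 3:
        n += 1
        yield n                            # execution pauses here ...
        log.append(f"resumed after {n}")   # ... and picks up here next time

log = []
gen = counter_demo(log)   # nothing in the body has run yet
first = next(gen)         # runs from the top to the first yield
second = next(gen)        # resumes after the yield; "start" is not logged again
print(first, second)      # 1 2
print(log)                # ['start', 'resumed after 1']
```

The local variable `n` survives between the two calls, which is exactly the remembered state that the examples below rely on.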

Running total

The accumulate() generator provides a running total of its iterator argument. We can write this as an explicit sum:
$$ T_N = \sum_{i=1}^N x_i$$ We can instead write this sum as a recurrence relation: $$ T_0 = 0 ; T_n = T_{n-1} + x_n$$ Expanding this out explicitly we get $$T_0 = 0 ; T_1 = T_0 + x_1 = 0 + x_1 = x_1 ; T_2 = T_1 + x_2 = x_1 + x_2 ; \ldots$$ +++T_0+++ is the initial state, which is observed directly. The recurrence term +++T_n+++ tells us what needs to be remembered of the previous state(s), and how to use that to generate the current value.
def running_total(gen):
    tot = 0
    for x in gen:
        tot += x
        yield tot

print_for(repeat(3))
print_for(running_total(repeat(3)))
print_for(running_total(count(1)))
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
3, 6, 9, 12, 15, 18, 21, 24, 27, 30, ...
1, 3, 6, 10, 15, 21, 28, 36, 45, 55, ...
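As a quick check, this sketch compares `running_total()` with the library's `accumulate()` on the same input; with no function argument, `accumulate()` also sums. The variable names `ours` and `theirs` are just for illustration.

```python
from itertools import accumulate, count, islice

def running_total(gen):
    tot = 0               # T_0, the remembered state
    for x in gen:
        tot += x          # T_n = T_{n-1} + x_n
        yield tot

ours = list(islice(running_total(count(1)), 10))
theirs = list(islice(accumulate(count(1)), 10))
print(ours)            # [1, 3, 6, 10, 15, 21, 28, 36, 45, 55]
print(ours == theirs)  # True
```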

Factorial

Similarly we can write the factorial as an explicit product:$$N! = F_N = \prod_{i=1}^N i$$ And we can write it as a recurrence relation: $$ F_1 = 1; F_n = n F_{n-1} $$ +++F_1+++ is the initial state, which should be the first output of the generator. We can yield this directly on the first call, then (on the next call) go into a loop yielding the following values.
def fact():
    f = 1
    yield f
    for n in count(2):
        f *= n
        yield f

print_for(fact())
1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800, ...
We can remove the special case, and instead calculate the next value after the yield in the loop. The subsequent call then picks up at that calculation.
def fact():
    f = 1
    for n in count(2):
        yield f
        f *= n

print_for(fact())
1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800, ...
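One way to convince ourselves that the loop version still starts at +++F_1+++ is to compare it against `math.factorial`. This is only a sanity check, not part of the generator itself.

```python
import math
from itertools import count, islice

def fact():
    f = 1
    for n in count(2):
        yield f      # yield the current value first ...
        f *= n       # ... then prepare the next one for the following call

first_ten = list(islice(fact(), 10))
expected = [math.factorial(k) for k in range(1, 11)]
print(first_ten == expected)  # True
```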

Fibonacci numbers

The perennially popular Fibonacci numbers are naturally defined using a recurrence relation, involving two previous states:$$F_1 = F_2 = 1 ; F_n = F_{n-1} + F_{n-2}$$ This demonstrates how we can store more state than just the result of the previous yield.
def fib(start=(1,1)):
    a,b = start
    while True:
        yield a
        a,b = b,a+b

print_for(fib(),20)
print_for(fib((0,3)),20)
1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, ...
0, 3, 3, 6, 9, 15, 24, 39, 63, 102, 165, 267, 432, 699, 1131, 1830, 2961, 4791, 7752,
 12543, ...
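Because the state is just the pair `(a, b)`, any two starting values give a sequence obeying the same recurrence. As an illustration (the choice of starting pair is mine, not from the examples above), starting from `(2, 1)` gives the Lucas numbers:

```python
from itertools import islice

def fib(start=(1, 1)):
    a, b = start
    while True:
        yield a
        a, b = b, a + b   # shift the two remembered values along

lucas = list(islice(fib((2, 1)), 10))
print(lucas)  # [2, 1, 3, 4, 7, 11, 18, 29, 47, 76]

# every term from the third onwards is the sum of the previous two
print(all(lucas[i] == lucas[i - 1] + lucas[i - 2] for i in range(2, 10)))  # True
```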

Logistic map

Difference equations, discrete-time analogues of differential equations, are a form of recurrence relation: the value of the state at the next time step is defined in terms of its value at previous timesteps.

The logistic map is a famous difference equation, exhibiting a range of periodic and chaotic behaviours, depending on the value of its parameter +++r+++. $$x_0 \in (0,1) ; x_{n+1} = r x_n(1-x_n)$$
def logistic_map(r=4):
    x = 0.1
    while True:
        yield x
        x = r * x * (1-x)

print_for(logistic_map())
print_for(logistic_map(2))
0.1, 0.36000000000000004, 0.9216, 0.28901376000000006, 0.8219392261226498, 
 0.5854205387341974, 0.970813326249438, 0.11333924730376121, 0.4019738492975123,
 0.9615634951138128, ...
0.1, 0.18000000000000002, 0.2952, 0.41611392, 0.4859262511644672, 0.49960385918742867, 
 0.49999968614491325, 0.49999999999980305, 0.5, 0.5, ...
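The +++r=2+++ output settles on 0.5 because that is a fixed point of the map: setting +++x = r x (1-x)+++ with +++x \neq 0+++ gives +++x = 1 - 1/r+++. A small sketch to check this numerically (the variable names are illustrative):

```python
import math
from itertools import islice

def logistic_map(r=4):
    x = 0.1
    while True:
        yield x
        x = r * x * (1 - x)

r = 2
fixed_point = 1 - 1 / r                           # non-zero solution of x = r*x*(1-x)
last = list(islice(logistic_map(r), 20))[-1]      # value after 19 iterations
print(fixed_point)                                # 0.5
print(math.isclose(last, fixed_point))            # True
```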
We can plot these values to see the chaotic (+++r=4+++) and periodic (+++r=3.5+++) behaviours over time. (The %matplotlib inline “magic” allows the plot to display in a Jupyter notebook.)
%matplotlib inline
import matplotlib.pyplot as py 

py.plot(list(islice(logistic_map(),200)))
py.plot(list(islice(logistic_map(3.5),200)))
We can also produce the more usual plot, with the parameter +++r+++ running along the +++x+++-axis, highlighting the various areas of periodicity and chaos.
from numpy import arange

start = 2.8
stop = 4
step = (stop-start)*0.002

skip = int(1/step) # no of initial vals to skip (converged to attractor)
npts = 150 # no of values to plot per value of r

for r in arange(start,stop,step):
    yl = logistic_map(r)
    list(islice(yl,skip)) # consume first few items
    py.scatter(list(islice(repeat(r,npts),npts)), list(islice(yl,npts)), marker='.', s=1) 

py.xlim(start,stop)
py.ylim(0,1)
(Here I have used s=1 to get a small point, rather than use a comma marker to get a pixel, because there is currently a bug in the use of pixels in scatter plots.)

Faster +++\pi+++

Many series for generating +++\pi+++ converge very slowly. One that converges extremely quickly is: $$ \pi = \sum_{i=0}^\infty \frac{(i!)^2 \, 2^{i+1}}{(2i+1)!}$$ We could use generators for each component of the term to code this up as:
import operator

def pi_term():
    fact = accumulate(count(1), operator.mul)
    factsq = (i*i for i in fact)

    twoi1 = (2**(i+1) for i in count(1))

    def fact2i1():
        i,f = 3,6
        while True:
            yield f
            f = f * (i+1) * (i+2)
            i += 2

    yield 2   # the i=0 term (needed because of 0! = 1 issues)
    for i in map(lambda x,y,z: x*y/z, factsq, twoi1, fact2i1()):
        yield i

print_for(accumulate(pi_term()),40)
2, 2.6666666666666665, 2.933333333333333, 3.0476190476190474, 3.098412698412698,
 3.121500721500721, 3.132156732156732, 3.1371295371295367, 3.1394696806461506, 
 3.140578169680336, 3.141106021601377, 3.1413584725201353, 3.1414796489611394, 
 3.1415379931734746, 3.1415661593449467, 3.1415797881375944, 3.1415863960370602, 
 3.141589605588229, 3.1415911669915006, 3.1415919276751456, 3.1415922987403384,
 3.141592479958223, 3.1415925685536337, 3.1415926119088344, 3.1415926331440347,
 3.1415926435534467, 3.141592648659951, 3.14159265116678, 3.141592652398205,
 3.1415926530034817, 3.1415926533011587, 3.141592653447635, 3.141592653519746,
 3.1415926535552634, 3.141592653572765, 3.1415926535813923, 3.141592653585647,
 3.141592653587746, 3.141592653588782, 3.141592653589293, ...
However, the factorials and powers that make up each term get very large, so the terms take a long time to calculate.
import timeit

%time [ i for i in islice(pi_term(),10000) if i < 0 ];
Wall time: 10.9 s
We can instead write the terms as a recurrence relation.

Let $$F_{k-1} = \frac{((k-1)!)^2 \, 2^{k}}{(2k-1)!}$$ Then $$\begin{eqnarray} F_{k} &=& \frac{(k!)^2 \, 2^{k+1}}{(2k+1)!} \\ &=& \frac{((k-1)! \, k)^2 \, 2 \times 2^{k}}{(2k-1)! \, (2k)(2k+1)} \\ &=& \frac{((k-1)!)^2 \, 2^{k} \, 2k^2}{(2k-1)! \, 2k(2k+1)} \\ &=& F_{k-1} \frac{k}{2k+1} \end{eqnarray}$$ So $$F_0 = 2 ; F_{n} = F_{n-1} \frac{n}{2n+1}$$
import math

def pi_term_rec():
    pt = 2
    for n in count(1):
        yield pt
        pt = pt * n/(2*n+1)

print_for(pi_term_rec(),30)
print_for(accumulate(pi_term_rec()),30)
print(math.pi)
We can see how quickly the terms shrink.
2, 0.6666666666666666, 0.26666666666666666, 0.1142857142857143, 0.0507936507936508,
 0.02308802308802309, 0.010656010656010658, 0.004972804972804974, 0.002340143516614105,
 0.0011084890341856288, 0.0005278519210407756, 0.0002524509187586318,
 0.00012117644100414327, 5.8344212335328244e-05, 2.816617147222743e-05,
 1.3628792647851983e-05, 6.607899465625204e-06, 3.209551169017956e-06,
 1.561403271414141e-06, 7.606836450479149e-07, 3.710651927063e-07,
 1.8121788481005348e-07, 8.85954103515817e-08, 4.335520081034849e-08,
 2.123520039690538e-08, 1.0409411959267343e-08, 5.1065039800179424e-09,
 2.5068292265542627e-09, 1.2314248832196377e-09, 6.052766375147372e-10, ...
2, 2.6666666666666665, 2.933333333333333, 3.0476190476190474, 3.098412698412698,
 3.121500721500721, 3.132156732156732, 3.1371295371295367, 3.1394696806461506,
 3.140578169680336, 3.141106021601377, 3.1413584725201353, 3.1414796489611394,
 3.1415379931734746, 3.1415661593449467, 3.1415797881375944, 3.1415863960370602,
 3.141589605588229, 3.1415911669915006, 3.1415919276751456, 3.1415922987403384,
 3.141592479958223, 3.1415925685536337, 3.1415926119088344, 3.1415926331440347,
 3.1415926435534467, 3.141592648659951, 3.14159265116678, 3.141592652398205,
 3.1415926530034817, ...
3.141592653589793
Not only is this code simpler, it is also much faster:
%time [ i for i in islice(pi_term_rec(),10000) if i < 0 ]
Wall time: 3.98 ms
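As a sanity check on the derivation, this sketch compares the recurrence with the explicit formula using exact rational arithmetic, so floating-point rounding cannot hide a discrepancy. The helper name `explicit_term` is an illustrative choice.

```python
from fractions import Fraction
from math import factorial

def explicit_term(i):
    """The i-th term of the series, (i!)^2 * 2^(i+1) / (2i+1)!, as an exact fraction."""
    return Fraction(factorial(i) ** 2 * 2 ** (i + 1), factorial(2 * i + 1))

term = Fraction(2)                        # F_0 = 2
agree = (term == explicit_term(0))
for n in range(1, 20):
    term = term * Fraction(n, 2 * n + 1)  # F_n = F_{n-1} * n / (2n+1)
    agree = agree and (term == explicit_term(n))

print(agree)  # True
```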

Sieving for primes

No discussion of generators would be complete without including a prime number generator. The standard algorithm is quite straightforward:
  • Generate consecutive whole numbers, and test each for divisibility.  If the current number isn’t divisible by anything, yield a new prime.
  • Only test for divisibility by primes, and only up to the square root of the number being tested; this requires keeping a record of the primes found so far.
  • Optimisation: treat 2 as a special case, and generate and test only the odd numbers.
def primessqrt():
    primessofar = []
    yield 2
    for n in count(3,2): # check odd numbers only, starting with 3
        sqrtn = int(math.sqrt(n))
        testprimes = takewhile(lambda i: i<=sqrtn, primessofar)
        isprime = all(n % p for p in testprimes) # n % p == 0 if n is divisible
        if isprime:  # if new prime, add to list and yield, else continue
            yield n
            primessofar.append(n)

print_while(primessqrt(), 200)
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
 89,  97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
 181, 191, 193, 197, 199, ...
This algorithm is often referred to as the Sieve of Eratosthenes, but the true sieve is somewhat different in its operation.  The sieve doesn’t require any square roots or divisions: it uses addition only, moving a marker through the numbers, striking out the multiples.  Using generators we can do this lazily, using a dictionary to store which marks have struck out a particular value (for example, when we get to 15, the dictionary entry will show that it has been struck out by 3).
  • Generate consecutive whole numbers, and check whether each one’s dictionary entry has any markers.
  • If there are no markers, yield a new prime +++p+++; start a new marker to strike out multiples of +++p+++ (start it at +++p^2+++, for the same reason the previous algorithm only needs to test values up to +++\sqrt n+++).
  • If there are markers (such as for 15), the value isn’t prime; move each marker on to the next value it strikes out (so for 15, move the 3 marker on 6 places to strike out 21: each marker is moved on by twice its value, since we are optimising by not considering even values).
  • Delete the current dictionary entry (so that the dictionary grows only as the number of primes found so far, rather than as the number of values checked, speeding up dictionary access).
from collections import defaultdict
def primesieve():
    yield 2
    sieve = defaultdict(set) # dict of n:{divisor} elems
    for n in count(3,2): # check odd numbers only
        if sieve[n] : # there are divisors, so not prime
            for d in sieve[n]: # move the sieve markers on
                sieve[n+d].add(d) 
        else: #set empty, so prime
            sieve[n*n].add(2*n)
            yield n
        # remove current dict item
        del sieve[n]

print_for(primesieve(), 100)
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277,
 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389,
 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499,
 503, 509, 521, 523, 541, ...
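To watch the markers move, here is a sketch of a tracing variant that yields each odd composite together with the markers found in its dictionary entry. The name `sieve_trace` is hypothetical. Note that, as in `primesieve()`, the value stored for a prime +++p+++ is its step +++2p+++, so the marker for 3 shows up as 6.

```python
from collections import defaultdict
from itertools import count, islice

def sieve_trace():
    """Like primesieve(), but yields (n, markers) for each odd composite n."""
    sieve = defaultdict(set)
    for n in count(3, 2):
        if sieve[n]:
            yield n, sorted(sieve[n])   # report which steps struck out n
            for d in sieve[n]:          # move each marker on by its step
                sieve[n + d].add(d)
        else:
            sieve[n * n].add(2 * n)     # new prime: start its marker at n^2
        del sieve[n]                    # keep the dictionary small

trace = list(islice(sieve_trace(), 6))
print(trace)
# [(9, [6]), (15, [6]), (21, [6]), (25, [10]), (27, [6]), (33, [6])]
```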
And this “true” sieve is faster than the divisor version.
%time [ i for i in islice(primessqrt(),100000) if i < 0 ]
%time [ i for i in islice(primesieve(),100000) if i < 0 ]
Wall time: 4.69 s
Wall time: 735 ms
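Speed is only interesting if the two generators agree. A minimal cross-check, repeating both definitions so that it runs on its own (the variable names are illustrative):

```python
import math
from collections import defaultdict
from itertools import count, islice, takewhile

def primessqrt():
    primessofar = []
    yield 2
    for n in count(3, 2):                 # odd numbers only
        sqrtn = int(math.sqrt(n))
        testprimes = takewhile(lambda i: i <= sqrtn, primessofar)
        if all(n % p for p in testprimes):
            yield n
            primessofar.append(n)

def primesieve():
    yield 2
    sieve = defaultdict(set)
    for n in count(3, 2):                 # odd numbers only
        if sieve[n]:
            for d in sieve[n]:
                sieve[n + d].add(d)
        else:
            sieve[n * n].add(2 * n)
            yield n
        del sieve[n]

by_division = list(islice(primessqrt(), 1000))
by_sieve = list(islice(primesieve(), 1000))
print(by_division == by_sieve)  # True
print(by_sieve[-1])             # 7919, the 1000th prime
```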
So full Python generators are even more fun than generator expressions! And in fact, there’s yet more fun to be had with generators, because it is possible to send values to them on each call, too; but that’s another story for another time.
