Tuesday, 21 November 2017

Fun with Python generator expressions

In the past, I have rhapsodised over Python’s list comprehensions. I had also heard of its generators, but never looked at them seriously. Recently I have been thinking about a problem in stream programming, where I will need the lazy evaluation that generators provide, so I have been looking at them in more detail. And I discovered that (in Python 3 at least) a list comprehension is, in effect, syntactic sugar for a generator expression whose output is turned into a list. That is, [ <expr> for i in <iter> if <cond> ] gives the same list as list( <expr> for i in <iter> if <cond> ). Let’s look at this in more detail.

With a list comprehension, I can construct, say, a list of the first 10 square numbers.
squares = [n*n for n in range(1,11)]
print(squares)
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
Since this is just syntactic sugar, it is the same as writing:
squares = list(n*n for n in range(1,11))
print(squares)
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
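The two forms differ in when the work gets done. A generator expression on its own, without list() wrapped round it, computes nothing when it is created. Here is a small sketch (the variable names are just illustrative) showing that it is a generator object, and that the built-in next() asks it for one value at a time:

```python
# A list comprehension builds every item straight away.
list_squares = [n*n for n in range(1, 11)]

# The same expression in parentheses builds nothing yet: it is a generator object.
gen_squares = (n*n for n in range(1, 11))
print(type(gen_squares).__name__)   # generator

# Each next() call computes just one more value.
print(next(gen_squares))   # 1
print(next(gen_squares))   # 4

# list() drains whatever is left.
rest = list(gen_squares)
print(rest)   # [9, 16, 25, 36, 49, 64, 81, 100]
```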
So far, so trivial. But the advantage of using the generator form is that it lazily iterates (it "generates") its items one at a time, as they are asked for. If you do not ask, it does not produce. This is useful if you don’t know how many items you want, and so cannot specify the stop value of the range(). A generator can be unbounded.
from itertools import *
squares = (n*n for n in count(1))
for s in squares:
    print(s, end=", ")
    if s >= 100:
        print('...')
        break
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
Here count(), imported from itertools, is an iterable like range() except that it has no stop value. It just keeps counting up for as long as you keep asking it for more values.
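To see why count() never stops, here is a generator function along the lines of the rough Python equivalent given in the itertools documentation. It is a simplified sketch, not the library's own code, and my_count is a hypothetical name. The `while True` loop simply pauses at each `yield` until the next value is asked for:

```python
from itertools import count

def my_count(start=0, step=1):
    """Simplified sketch of itertools.count(): yield start, start+step, ... forever."""
    n = start
    while True:
        yield n        # pause here until the caller asks for the next value
        n += step

mine = my_count(1)
print([next(mine) for _ in range(5)])     # [1, 2, 3, 4, 5]

# The real count() also takes optional start and step arguments.
theirs = count(10, 5)
print([next(theirs) for _ in range(4)])   # [10, 15, 20, 25]
```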

The module itertools has many such useful functions you can use with potentially infinite generators. For example, we can islice() off the first few items of a potentially unbounded generator, to give a bounded iterator that returns only those items. Here we slice off the first 10 values:
squares = (n*n for n in count(1))
firstfew = islice(squares,10)
print(list(firstfew))
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
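islice() also accepts start, stop and step arguments, much like slicing a list (although, unlike list slicing, it does not allow negative values). A small sketch picking out every other square from the third one onwards:

```python
from itertools import count, islice

squares = (n*n for n in count(1))
# Like squares_list[2:10:2]: skip the first 2 items, then take every 2nd item before index 10.
every_other = list(islice(squares, 2, 10, 2))
print(every_other)   # [9, 25, 49, 81]
```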
This is fine if we know how many items we want. Sometimes instead we want all the items up to a particular value. We can use takewhile() for this, here to get the squares up to 150.
squares = (n*n for n in count(1))
firstfew = takewhile(lambda n: n<=150, squares)
print(list(firstfew))
[1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144]
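One subtlety: takewhile() stops for good at the first item that fails its predicate, even if later items would pass. That is what makes it safe on an unbounded generator, where a filtering `if` clause would keep looking forever. A small finite list shows the difference:

```python
from itertools import takewhile

data = [1, 3, 6, 2, 4]
# takewhile stops at 6, the first item that fails the test.
taken = list(takewhile(lambda n: n < 5, data))
print(taken)      # [1, 3]
# A filter tests every item, so it picks up 2 and 4 as well.
filtered = [n for n in data if n < 5]
print(filtered)   # [1, 3, 2, 4]
```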
Let’s use islice() and takewhile() to define a couple of "pretty print" functions:
# print the first n items of the generator, comma separated
def print_for(gen,n=10):
    print(*list(islice(gen,n)), sep=', ', end=", ...\n")

# print the generator up to value nmax, comma separated
def print_while(gen,nmax=250):
    print(*list(takewhile(lambda n: n<=nmax, gen)), sep=', ', end=", ...\n")

print_for(n*n for n in count(1))
print_while(n*n for n in count(1))
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, ...
squares = (n*n for n in count(1))
print_for(squares)

squares = (n*n for n in count(1))
print_while(squares)
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, ...
There are analogous functions dropwhile(), which drops items while its predicate is true and then yields everything after that, and filterfalse(), which drops every item for which its predicate is true (that is, it keeps only the items for which the predicate is false).
print_for(dropwhile(lambda n: n<=5, count(1)))  # drop items while they are <= 5
print_for(filterfalse(lambda n: n%3, count(1))) # filter out items not divisible by 3
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, ...
3, 6, 9, 12, 15, 18, 21, 24, 27, 30, ...
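Since dropwhile() and filterfalse() are easy to confuse, here is a small finite comparison using the same predicate for both. dropwhile() only drops a leading run of items, while filterfalse() tests every item:

```python
from itertools import dropwhile, filterfalse

data = [1, 6, 2, 7]
# dropwhile: drops 1, stops testing at 6, then keeps everything after (including 2).
dropped = list(dropwhile(lambda n: n < 5, data))
print(dropped)   # [6, 2, 7]
# filterfalse: drops every item for which the test is true.
kept = list(filterfalse(lambda n: n < 5, data))
print(kept)      # [6, 7]
```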

Restarting generators

I have been quite careful above to keep redefining squares. This is because once an item is consumed, it is gone. If I take a generator instance, like squares, and slice off the first ten items, then slice again, I get the next ten items. For example:
squares = (n*n for n in count(1))

print_for(squares)
print_for(squares)  # continues from next value of same generator

print_for(n*n for n in count(1))
print_for(n*n for n in count(1)) # restarts, with new generator
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
121, 144, 169, 196, 225, 256, 289, 324, 361, 400, ...
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
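With a finite generator the same effect shows up as exhaustion: once every item has been produced, asking again simply yields nothing, and no error is raised. A minimal sketch:

```python
squares = (n*n for n in range(1, 4))
first_pass = list(squares)
second_pass = list(squares)
print(first_pass)    # [1, 4, 9]
print(second_pass)   # [] (the generator is already exhausted)
```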
You do need to ensure you consume the generated items for this to occur. Consider:
squares = (n*n for n in count(1))
# slice off the first 10?
islice(squares, 10)
print_for(squares)  # maybe not what is expected
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
The islice() call produced a new iterator. But since nothing accessed that iterator, nothing actually got sliced off squares. Compare:
squares = (n*n for n in count(1))
# slice off and consume the first 10
list(islice(squares, 10))
print_for(squares)
121, 144, 169, 196, 225, 256, 289, 324, 361, 400, ...
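Building a throwaway list just to advance an iterator works, but it stores every skipped item along the way. The recipes section of the itertools documentation has a "consume" recipe that avoids this; here is a sketch along those lines (consume is a helper you define yourself, not something itertools exports):

```python
from itertools import count, islice

def consume(iterator, n):
    """Advance iterator by n items, discarding them without building a list."""
    # islice(iterator, n, n) skips n items and then yields nothing;
    # next() triggers the skipping, and the None default avoids StopIteration.
    next(islice(iterator, n, n), None)

squares = (n*n for n in count(1))
consume(squares, 10)
print(next(squares))   # 121
```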

tee() for two

Once it’s gone, it’s gone. But what if you want it back? You could always iterate again. But what if the items are expensive to compute? For example, you may be reading a large file, and don’t want to read it all again. Here tee() is useful for "remembering" earlier items. (Note: tee(iterable, n), where n defaults to 2, doesn’t make n copies of an iterator; rather it gives n pointers into a single iterator, buffering the items one pointer has already seen until the others catch up.)
sq_ptr = tee((n*n for n in count(1)))

print_for(sq_ptr[0])
print_for(sq_ptr[0],5)
print_for(sq_ptr[1])
print_for(sq_ptr[1])
print_for(sq_ptr[1],5)
print_for(sq_ptr[0])
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
121, 144, 169, 196, 225, ...
1, 4, 9, 16, 25, 36, 49, 64, 81, 100, ...
121, 144, 169, 196, 225, 256, 289, 324, 361, 400, ...
441, 484, 529, 576, 625, ...
256, 289, 324, 361, 400, 441, 484, 529, 576, 625, ...
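One caveat, which the itertools documentation also points out: once tee() has been called, the original iterator should not be used directly, because tee() cannot see items taken from it behind its back. A quick illustration:

```python
from itertools import tee

it = iter(range(5))
a, b = tee(it)
next(it)           # advance the original directly: 0 is now lost to both pointers
from_a = list(a)
from_b = list(b)
print(from_a)      # [1, 2, 3, 4]
print(from_b)      # [1, 2, 3, 4]
```

The buffering also has a cost: if one pointer runs far ahead of the others, tee() has to keep every item in between in memory.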
One application is building a "sliding window" over the iterated data. Here is an example for a window of size 3:
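def triples(iterable):
    a, b, c = tee(iterable,3)  # three pointers into the same data
    next(b)                    # advance b by one item
    list(islice(c,2))          # advance c by two items
    return zip(a, b, c)        # (x0, x1, x2), (x1, x2, x3), ...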

print_for(triples(count(1)))
print_for(triples('abcdefghijklmnopqrstuvwxyz'))
(1, 2, 3), (2, 3, 4), (3, 4, 5), (4, 5, 6), (5, 6, 7), (6, 7, 8), (7, 8, 9), (8, 9, 10),
(9, 10, 11), (10, 11, 12), ...
('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e'), ('d', 'e', 'f'), ('e', 'f', 'g'), 
('f', 'g', 'h'), ('g', 'h', 'i'), ('h', 'i', 'j'), ('i', 'j', 'k'), ('j', 'k', 'l'), ...
and here it is for a user-defined window size:
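def sliding_window(iterable,n=2):
    windows = tee(iterable,n)       # n pointers into the same data
    for i in range(n):
        list(islice(windows[i],i))  # advance the i-th pointer by i items
    return zip(*windows)            # zip them into overlapping n-tuples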

print_for(sliding_window(count(1),4))
print_for(sliding_window('abcdefghijklmnopqrstuvwxyz',3))
(1, 2, 3, 4), (2, 3, 4, 5), (3, 4, 5, 6), (4, 5, 6, 7), (5, 6, 7, 8), (6, 7, 8, 9), 
(7, 8, 9, 10), (8, 9, 10, 11), (9, 10, 11, 12), (10, 11, 12, 13), ...
('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e'), ('d', 'e', 'f'), ('e', 'f', 'g'), 
('f', 'g', 'h'), ('g', 'h', 'i'), ('h', 'i', 'j'), ('i', 'j', 'k'), ('j', 'k', 'l'), ...
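tee() is not the only way to build a sliding window. The recipes section of the itertools documentation has a sliding-window recipe that uses a collections.deque with a fixed maximum length instead. Here is a sketch in that style (sliding_window_deque is a hypothetical name):

```python
from collections import deque
from itertools import count, islice

def sliding_window_deque(iterable, n=2):
    """Yield overlapping n-tuples from iterable, reading it only once."""
    it = iter(iterable)
    # Pre-fill the window with the first n-1 items.
    window = deque(islice(it, n - 1), maxlen=n)
    for x in it:
        window.append(x)      # maxlen=n makes the oldest item fall off the left
        yield tuple(window)

print(list(islice(sliding_window_deque(count(1), 4), 3)))
# [(1, 2, 3, 4), (2, 3, 4, 5), (3, 4, 5, 6)]
print(list(sliding_window_deque('abcde', 3)))
# [('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e')]
```

This version reads each input item exactly once and never holds more than n of them; an input shorter than n simply yields no windows.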

Sums and products

The itertools module has a function accumulate() that takes an iterable and returns an iterator of running totals: each item is the sum of all the values up to that point. We could use this to implement a simple version of count() "the hard way", by using repeat(), which simply repeats its argument endlessly:
print_for(repeat(3))
print_for(accumulate(repeat(1)))
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...
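For numbers, the summing behaviour of accumulate() can be sketched as a small generator function. This is only an approximation: running_totals is a hypothetical name, and the sketch ignores accumulate()'s optional function and initial-value arguments.

```python
from itertools import islice, repeat

def running_totals(iterable):
    """Yield the sum of the items seen so far, one item at a time."""
    total = 0
    for x in iterable:
        total += x
        yield total

print(list(islice(running_totals(repeat(1)), 5)))   # [1, 2, 3, 4, 5]
print(list(running_totals([5, -2, 10])))            # [5, 3, 13]
```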
The triangle numbers are 1, 1+2, 1+2+3, ... We can use count() and accumulate() to generate these:
print_for(accumulate(count(1)))
1, 3, 6, 10, 15, 21, 28, 36, 45, 55, ...
accumulate() defaults to summing the items, but another two-argument function can be passed in instead, such as multiplication from the operator module:
import operator
factorial = accumulate(count(1), operator.mul)
print_for(factorial)
1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800, ...
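Any two-argument function will do. For instance, passing the built-in max() gives a running maximum, the largest value seen so far:

```python
from itertools import accumulate

data = [3, 1, 4, 1, 5, 9, 2, 6]
running_max = list(accumulate(data, max))
print(running_max)   # [3, 3, 4, 4, 5, 9, 9, 9]
```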

Infinite series and products

We can use accumulate() to calculate infinite sums and products (or at least, the partial sums and products of however many terms we ask for).

sum of reciprocal squares

$$ \sum_{n=1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6} $$
import math

recip_squares = (1/(n*n) for n in count(1))
print_for(accumulate(recip_squares))

# or all in one go, now with 30 terms:
print_for(accumulate(1/(n*n) for n in count(1)),30)

# it converges rather slowly
print('limit =', math.pi*math.pi/6)
1.0, 1.25, 1.3611111111111112, 1.4236111111111112, 1.4636111111111112, 1.4913888888888889,
1.511797052154195, 1.527422052154195, 1.5397677311665408, 1.5497677311665408, ...
1.0, 1.25, 1.3611111111111112, 1.4236111111111112, 1.4636111111111112, 1.4913888888888889,
1.511797052154195, 1.527422052154195, 1.5397677311665408, 1.5497677311665408, 
1.558032193976458, 1.5649766384209025, 1.5708937981842162, 1.5759958390005426, 
1.580440283444987, 1.584346533444987, 1.587806741057444, 1.5908931608105303, 
1.5936632439130234, 1.5961632439130233, 1.5984308176091684, 1.6004969333116477, 
1.6023872924798896, 1.6041234035910008, 1.6057234035910009, 1.6072026935318293, 
1.6085744356443121, 1.6098499458483937, 1.6110390064904865, 1.6121501176015975, ...
limit = 1.6449340668482264
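To put a number on "rather slowly", the partial sums can be fed through dropwhile() until one gets within some tolerance of the limit. The sketch below uses enumerate() to count the terms and an arbitrary tolerance of 0.001. Comparing the tail of the series with the integral of 1/x² shows that the error after N terms lies between 1/(N+1) and 1/N, so this should need exactly 1000 terms.

```python
import math
from itertools import accumulate, count, dropwhile

limit = math.pi*math.pi/6
tolerance = 1e-3   # arbitrary choice for this illustration

# Pair each partial sum with the number of terms used to get it.
partial_sums = enumerate(accumulate(1/(n*n) for n in count(1)), start=1)

# Skip partial sums until one is within the tolerance of the limit.
n_terms, value = next(dropwhile(lambda p: limit - p[1] > tolerance, partial_sums))
print(n_terms)   # 1000
```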

sum of reciprocal powers of 2

$$ \sum_{n=1}^\infty \frac{1}{2^n} = 1 $$
print_for(accumulate(1/(2**n) for n in count(1)),20)
0.5, 0.75, 0.875, 0.9375, 0.96875, 0.984375, 0.9921875, 0.99609375, 0.998046875, 
0.9990234375, 0.99951171875, 0.999755859375, 0.9998779296875, 0.99993896484375, 
0.999969482421875, 0.9999847412109375, 0.9999923706054688, 0.9999961853027344, 
0.9999980926513672, 0.9999990463256836, ...
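Here the partial sums have a simple closed form, since each new term halves the remaining gap to 1:

$$ \sum_{n=1}^N \frac{1}{2^n} = 1 - \frac{1}{2^N} $$

The sketch below checks that against the generated values. Because every term is a power of two, these floating-point sums happen to be exact for small N, so an exact equality test is safe here.

```python
from itertools import accumulate, count

partial_sums = accumulate(1/(2**n) for n in count(1))
for N, s in zip(range(1, 21), partial_sums):
    # each partial sum equals 1 - 1/2**N exactly
    assert s == 1 - 1/2**N
print("closed form holds for N = 1 .. 20")
```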

factorial

$$ \prod_{i=1}^n i = n! $$
print_for(accumulate((n for n in count(1)), operator.mul))
1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800, ...
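Python 3.8 added an initial keyword argument to accumulate(), which puts a starting value in front of the results. With initial=1 the sequence starts from 0! = 1, and the values can be checked against math.factorial():

```python
import math
import operator
from itertools import accumulate, count, islice

# initial=1 supplies 0! before the products 1, 1*2, 1*2*3, ...
factorials = accumulate(count(1), operator.mul, initial=1)
first_six = list(islice(factorials, 6))
print(first_six)                                            # [1, 1, 2, 6, 24, 120]
print(first_six == [math.factorial(k) for k in range(6)])   # True
```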

a product for π

$$ \prod_{n=1}^\infty \left( \frac{4n^2}{4n^2-1} \right) = \frac{\pi}{2} $$
print_for(accumulate((1/(1-1/(4*n**2)) for n in count(1)), operator.mul),20)

# it also converges rather slowly
print('limit =', math.pi/2)
1.3333333333333333, 1.422222222222222, 1.4628571428571429, 1.4860770975056687, 
1.5010879772784533, 1.51158509600068, 1.5193368144417092, 1.525294998027755, 
1.5300172735634447, 1.533851903321749, 1.5370275801402207, 1.539700671583943, 
1.5419817096159192, 1.5439510349155563, 1.5456684442981097, 1.5471793616434646, 
1.5485189108743247, 1.549714678373069, 1.5507886317191353, 1.5517584807696163, ...
limit = 1.5707963267948966
This is all very neat and nifty, and we have barely scratched the surface of what can be done. But that's probably (more than) enough for now.

