Understanding Generators

Last Updated: 2023-12-20 10:37:21 Wednesday


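Generators are a kind of function interface that feels especially at home in an interpreted, dynamic language like Python. Such a function can stop somewhere in the middle of its body and hand an object back to the caller; the next time it runs, it picks up exactly where it last returned!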

We call this kind of interface a generator!

Generator Basics

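A generator function can keep handing out objects, each time resuming from the point where it last returned, either forever or until its own logic runs to completion. It behaves like a state machine that remembers where it is: every call continues from the previous stopping point instead of starting over. For example, generating the Fibonacci sequence: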

>>> def fib():
...   a = 1
...   yield a
...   b = 1
...   yield b
...   while True:
...     yield a+b
...     a,b = b,a+b
...
>>> f = fib()
>>> next(f)
1
>>> next(f)
1
>>> next(f)
2
>>> next(f)
3
>>> next(f)
5
>>> next(f)
8
>>> next(f)
13
>>> next(f)
21
>>> next(f)
34
>>> next(f)
55
>>> f.close()
>>> fib()
<generator object fib at 0x7f0a308a7d80>

Calling fib returns a generator object, and next() advances it one value at a time. Iterating it with a for loop is, of course, much neater:

>>> for i in fib():
...   if i < 100:
...     print(i)
...   else:
...     break
...
1
1
2
3
5
8
13
21
34
55
89
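Instead of an explicit break, the standard itertools helpers can cut an infinite generator short. A minimal sketch reusing the same fib definition; takewhile stops at the first value that fails the test, while islice simply takes a fixed count:

```python
from itertools import islice, takewhile

def fib():
    # Same infinite Fibonacci generator as above
    a = 1
    yield a
    b = 1
    yield b
    while True:
        yield a + b
        a, b = b, a + b

# All Fibonacci numbers below 100: takewhile stops at the first value >= 100
below_100 = list(takewhile(lambda x: x < 100, fib()))
# The first five values, regardless of their size
first_five = list(islice(fib(), 5))

print(below_100)   # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
print(first_five)  # [1, 1, 2, 3, 5]
```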

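If the yield keyword appears anywhere in a function body, the function is a generator function, even if that yield can never execute! Calling it still returns a generator, and such a generator raises StopIteration on the very first next(). (As a CS401 professor put it: a function that contains yield is actually a coroutine! Each one has its own independent stack.)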
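A small check of that claim. The function name never_yields is just for illustration; inspect.isgeneratorfunction is a standard-library call:

```python
import inspect

def never_yields():
    print("body runs")
    if False:
        yield  # never executed, but its mere presence makes this a generator function

print(inspect.isgeneratorfunction(never_yields))  # True

g = never_yields()   # nothing is printed yet: calling it only creates the generator
raised = False
try:
    next(g)          # prints "body runs", then the body ends without yielding
except StopIteration:
    raised = True
print(raised)        # True
```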

Every generator you create is independent of the others, even when they come from the same definition.
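To see the independence concretely, here is a sketch with a hypothetical counter generator: each generator object keeps its own local variables, so advancing one has no effect on the other.

```python
def counter():
    n = 0
    while True:
        n += 1
        yield n

c1 = counter()
c2 = counter()                         # a second generator from the same definition

first = [next(c1) for _ in range(3)]   # c1 advances through its own n
second = next(c2)                      # c2 still starts from scratch

print(first)   # [1, 2, 3]
print(second)  # 1
```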

When a generator's body runs to its natural end, StopIteration is raised, and that is how a for loop knows to stop on its own:

>>> def gtor():
...   yield 1
...   yield 2,3  # tuple
...
>>> for i in gtor():
...   print(i)
...
1
(2, 3)
>>> g = gtor()
>>> next(g)
1
>>> next(g)
(2, 3)
>>> next(g)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration

Calling next() on a generator after close() has been called on it also raises StopIteration:

>>> g = gtor()
>>> g.close()
>>> next(g)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration
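close() works by raising GeneratorExit inside the generator at the yield where it is paused, so a try/finally around the yields gets a chance to clean up. A minimal sketch; the resource name and the events list are made up for illustration:

```python
events = []

def resource():
    events.append("open")
    try:
        yield 1
        yield 2
    finally:
        # Runs when close() raises GeneratorExit at the paused yield
        events.append("cleanup")

r = resource()
print(next(r))   # 1
r.close()        # the generator never reaches "yield 2"; the finally block runs
print(events)    # ['open', 'cleanup']
```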

When a generator finishes with a return statement, the StopIteration it raises carries the returned value in its value attribute:

>>> def gtor():
...     yield 1
...     return 99
...
>>> g = gtor()
>>> next(g)
1
>>> try:
...   next(g)
... except StopIteration as e:
...   print(e.value)
...
99

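What's more surprising is that a generator does not only pass values out with yield: outside code can also pass input in with send(). The sent value becomes the result of the paused yield expression, and send() resumes the generator just as next() does: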

>>> def gs():
...   a = yield 1
...   b = yield 2
...   print(a,b)
...
>>> s = gs()
>>> next(s)
1
>>> s.send('abc')
2
>>> try:
...   s.send('123')
... except StopIteration:
...   pass
...
abc 123

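The code above first uses next() to start the generator running, then makes two send() calls; note that the last one raises StopIteration because the generator body finishes. You can also use send() throughout, but the first call must be send(None), since a generator that has not started yet has no paused yield to receive a value: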

>>> s = gs()
>>> s.send(None)
1
>>> s.send('abc')
2
>>> s.send('123')
abc 123
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration
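What happens if the first send() carries a real value? A sketch reusing gs; CPython rejects it with a TypeError (the message checked below is CPython's wording), while a fresh generator started with send(None) behaves exactly like one started with next():

```python
def gs():
    a = yield 1
    b = yield 2
    print(a, b)

s = gs()
error = None
try:
    s.send('abc')   # not started yet: there is no paused yield to receive 'abc'
except TypeError as e:
    error = str(e)
print(error)        # can't send non-None value to a just-started generator

s2 = gs()
first = s2.send(None)   # starting with send(None) works just like next(s2)
print(first)            # 1
```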

In other words, as a generator runs, suspends, and resumes, data can flow in both directions, which makes more complex designs possible.
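A classic illustration of this two-way flow is a running average: each send() pushes a number in, and the value yielded back is the average so far. A minimal sketch; the name averager is hypothetical:

```python
def averager():
    """Receive numbers via send() and yield the running average so far."""
    total = 0.0
    count = 0
    average = None
    while True:
        value = yield average   # hand out the current average, then wait for input
        total += value
        count += 1
        average = total / count

avg = averager()
next(avg)                                      # prime it: run up to the first yield
results = [avg.send(x) for x in (10, 20, 60)]
print(results)                                 # [10.0, 15.0, 30.0]
```

The generator's local variables (total, count) persist between sends, so no class or global state is needed to remember the history.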

You can even throw an exception into a generator with throw(); it is raised at the yield where the generator is paused:

>>> def catch():
...   try:
...     yield 1
...   except Exception as e:
...     print('catch:', str(e))
...
>>> c = catch()
>>> next(c)
1
>>> c.throw(ValueError('aaa'))
catch: aaa
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
StopIteration
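In the example above the generator ends after handling the exception, so throw() raises StopIteration. If the generator handles the exception and yields again, throw() returns that yielded value instead; an exception the generator does not catch propagates back to the caller. A sketch with a hypothetical resilient generator:

```python
def resilient():
    while True:
        try:
            yield "ready"
        except ValueError as e:
            # Handle the thrown exception and keep running;
            # throw() returns whatever the generator yields next
            yield f"recovered from {e}"

r = resilient()
values = [next(r), r.throw(ValueError("bad")), next(r)]
print(values)   # ['ready', 'recovered from bad', 'ready']

# KeyError is not caught inside, so it propagates out of throw()
propagated = False
try:
    r.throw(KeyError("k"))
except KeyError:
    propagated = True
print(propagated)   # True
```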

yield can also be used on its own, with no expression after it; this is equivalent to yield None.
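A quick sketch (function names are illustrative only): a bare yield hands out None, but it is still an expression that can receive a value from send():

```python
def echo_double():
    received = yield        # bare yield: hands out None, then waits for send()
    yield received * 2

e = echo_double()
first = next(e)             # None
second = e.send(21)         # 42
print(first, second)        # None 42

def two_nones():
    yield
    yield

print(list(two_nones()))    # [None, None]
```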

Generator States

The inspect module provides getgeneratorstate(), which reports a generator's current state. There are four possible states:

getgeneratorstate(generator)
    Get current state of a generator-iterator.

    Possible states are:
      GEN_CREATED: Waiting to start execution.
      GEN_RUNNING: Currently being executed by the interpreter.
      GEN_SUSPENDED: Currently suspended at a yield expression.
      GEN_CLOSED: Execution has completed.

Test code:

>>> from inspect import getgeneratorstate
>>> g = gtor()
>>> getgeneratorstate(g)
'GEN_CREATED'
>>> next(g)
1
>>> getgeneratorstate(g)
'GEN_SUSPENDED'
>>> try:
...   next(g)
...   except StopIteration:
...   pass
...
>>> getgeneratorstate(g)
'GEN_CLOSED'
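The test above never shows GEN_RUNNING, because from the outside a generator is only observable while it is not executing. One way to see it is to query the state from inside the generator body. A sketch; it relies on the generator looking up the module-level name g, which is purely an illustration trick:

```python
from inspect import getgeneratorstate

def self_report():
    # While this body is executing, the generator object g is GEN_RUNNING
    yield getgeneratorstate(g)

g = self_report()
states = [getgeneratorstate(g)]        # GEN_CREATED
states.append(next(g))                 # GEN_RUNNING, observed from inside
g.close()
states.append(getgeneratorstate(g))    # GEN_CLOSED
print(states)  # ['GEN_CREATED', 'GEN_RUNNING', 'GEN_CLOSED']
```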

yield from

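yield hands out a single object, whereas yield from hands out each element of an iterable, one after another; the expression after yield from must be an iterable. I once saw an implementation that used yield from to turn a BSTree traversal into a generator you can simply read values from!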

>>> def tyf():
...     a = (i for i in range(3))  # generator expression
...     yield from a
...
>>> t = tyf()
>>> next(t)
0
>>> next(t)
1
>>> next(t)
2
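For plain iteration, yield from it produces the same items as an explicit for x in it: yield x loop; it just reads better. A sketch comparing the two forms on a mix of iterables (both function names are hypothetical):

```python
def flatten_for(iterables):
    for it in iterables:
        for item in it:      # explicit inner loop
            yield item

def flatten_yf(iterables):
    for it in iterables:
        yield from it        # same items, delegated to the inner iterable

data = [[1, 2], "ab", range(3)]
print(list(flatten_for(data)))   # [1, 2, 'a', 'b', 0, 1, 2]
print(list(flatten_yf(data)))    # [1, 2, 'a', 'b', 0, 1, 2]
```

The two forms differ once send(), throw(), or a return value are involved, since only yield from forwards those to the inner generator.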

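When yield from delegates to another generator, it has some special properties: values passed in with send() go straight to the inner generator, and throw() and close() are forwarded to it as well: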

def gen123():
    yield 1
    yield 2
    yield 3

def gen789():
    a = yield 7
    a = yield a
    yield a

def my_gen():
    yield from gen123()
    print('take a breath...')
    yield from gen789()

g = my_gen()
print(next(g))
print(next(g))
print(next(g))
print(next(g))
print(g.send(8))
print(g.send(9))
try:
    next(g)
except StopIteration:
    print('end')

Output:

1
2
3
take a breath...
7
8
9
end
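One more special property: when the inner generator finishes with return, the returned value becomes the value of the yield from expression itself. A sketch; sub, outer, and the log list are illustrative names:

```python
def sub():
    yield 1
    yield 2
    return "sub done"      # becomes the value of the yield from expression

def outer(log):
    result = yield from sub()
    log.append(result)     # record what the subgenerator returned
    yield 3

log = []
items = list(outer(log))
print(items)   # [1, 2, 3]
print(log)     # ['sub done']
```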

Below is a Python implementation of in-order binary tree traversal, from LeetCode problem 94:

from typing import List, Optional

# Definition for a binary tree node (LeetCode provides this class).
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def inorderTraversal(self, root: Optional[TreeNode]) -> List[int]:
        def _inorder(root):
            # Yield the left subtree, then this node, then the right subtree
            if root:
                yield from _inorder(root.left)
                yield root.val
                yield from _inorder(root.right)
        return list(_inorder(root))
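Because the traversal is a generator, you don't have to materialize the whole list. On a binary search tree, in-order order is sorted order, so itertools.islice can take just the k smallest values and the traversal stops as soon as it has produced them. A sketch with a hand-built example tree; the standalone inorder function mirrors _inorder above:

```python
from itertools import islice

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def inorder(root):
    # Same recursive generator as _inorder above
    if root:
        yield from inorder(root.left)
        yield root.val
        yield from inorder(root.right)

#        4
#       / \
#      2   6
#     / \   \
#    1   3   7
root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6, None, TreeNode(7)))

print(list(inorder(root)))             # [1, 2, 3, 4, 6, 7]
print(list(islice(inorder(root), 2)))  # [1, 2]: traversal stops after two values
```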

@contextmanager

The standard library module contextlib.py defines this decorator as follows:

def contextmanager(func):
    """@contextmanager decorator.
    Typical usage:
        @contextmanager
        def some_generator(<arguments>):
            <setup>
            try:
                yield <value>
            finally:
                <cleanup>
    This makes this:
        with some_generator(<arguments>) as <variable>:
            <body>
    equivalent to this:
        <setup>
        try:
            <variable> = <value>
            <body>
        finally:
            <cleanup>
    """
    @wraps(func)
    def helper(*args, **kwds):
        return _GeneratorContextManager(func, args, kwds)
    return helper

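Decorate a generator function with @contextmanager and you can drive that generator with a with statement: the code before yield is the setup, the yielded value is what as binds, and the code after yield is the cleanup!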
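A minimal sketch following the docstring's setup/try/yield/finally pattern; tag and the log lists are made up for illustration. The second with block shows that the cleanup still runs when the body raises, because contextmanager throws the exception into the generator at its yield:

```python
from contextlib import contextmanager

@contextmanager
def tag(name, log):
    log.append(f"<{name}>")          # setup: runs when the with block is entered
    try:
        yield name                   # the yielded value is bound by "as"
    finally:
        log.append(f"</{name}>")     # cleanup: runs even if the body raises

log = []
with tag("p", log) as t:
    log.append(t.upper())
print(log)    # ['<p>', 'P', '</p>']

log2 = []
try:
    with tag("div", log2):
        raise RuntimeError("boom")
except RuntimeError:
    pass
print(log2)   # ['<div>', '</div>']
```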

Python generators are remarkably versatile, and how flexibly you can put them to use is a very important measure of a programmer's craft.

Generators are closely related to the much-talked-about coroutine.
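To make the connection tangible, here is a toy round-robin scheduler: each generator runs until its next yield and then gives control back, which is cooperative multitasking in miniature. This is only an illustration under simple assumptions (task and run are hypothetical names); real asynchronous code in Python would normally use asyncio instead:

```python
from collections import deque

def task(name, steps, trace):
    for i in range(steps):
        trace.append(f"{name}{i}")
        yield            # give control back to the scheduler

def run(tasks):
    """Round-robin: resume each generator in turn until all are finished."""
    queue = deque(tasks)
    while queue:
        t = queue.popleft()
        try:
            next(t)
        except StopIteration:
            continue     # this task is done; drop it
        queue.append(t)  # not done yet: put it back at the end of the line

trace = []
run([task("A", 3, trace), task("B", 2, trace)])
print(trace)   # ['A0', 'B0', 'A1', 'B1', 'A2']
```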

Implementing __iter__ with a Generator

A generator method: implement an object's __iter__ as a generator.

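class joy:

    def __init__(self, start, end):
        self.s = start
        self.e = end

    def __iter__(self):
        # A generator method: each call returns a new, independent generator
        for i in range(self.s, self.e):
            yield i


a = joy(0, 4)
b = joy(5, 9)
c = iter(a)
next(c)  # advance c past 0
for i, j, k in zip(iter(a), iter(b), c):
    print(i, j, k)
# Output:
# 0 5 1
# 1 6 2
# 2 7 3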

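c comes from a, and the loop creates another iterator from a, yet the two are independent of each other: each call to iter(a) runs __iter__ again and returns a brand-new generator.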
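This is also the practical difference between an object whose __iter__ is a generator and a bare generator: the object can be iterated again and again, while a single generator is a one-shot iterator that is empty after one pass (and iter() on it returns the generator itself). A sketch using the same joy class, written with yield from for brevity:

```python
class joy:
    def __init__(self, start, end):
        self.s = start
        self.e = end

    def __iter__(self):
        # Each call returns a brand-new generator
        yield from range(self.s, self.e)

a = joy(0, 4)
print(list(a), list(a))   # [0, 1, 2, 3] [0, 1, 2, 3]: a can be iterated again

g = iter(a)               # one generator: a one-shot iterator
print(iter(g) is g)       # True
print(list(g), list(g))   # [0, 1, 2, 3] []
```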

Passing a Generator Where an Iterable Is Expected

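When a function takes an iterable as a parameter, you can pass it a generator! A generator is an iterable, so there is no need to build a list (or some other iterable) first and then pass that in; the code is more concise and elegant! :)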

>>> def test(n):
...   print(type(n))
...   for i in n:
...     print(i)
...
>>> test(i for i in range(4))
<class 'generator'>
0
1
2
3

Or:

>>> def g():
...   for i in range(4):
...     yield i
...
>>> test(g())
<class 'generator'>
0
1
2
3

Many Python built-in functions accept an iterable as an argument:

>>> sum(i for i in range(8))
28
>>> ''.join(str(i) for i in range(8))
'01234567'
>>> bytes(i for i in range(8))
b'\x00\x01\x02\x03\x04\x05\x06\x07'
>>> tuple(i for i in range(8))
(0, 1, 2, 3, 4, 5, 6, 7)
>>> set(i for i in range(8))
{0, 1, 2, 3, 4, 5, 6, 7}
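A generator argument is lazy, so short-circuiting built-ins such as any() pull only as many values as they need. Note also that a generator expression may drop its own parentheses only when it is the sole argument; with any extra argument it must be parenthesized. A sketch; numbers and the consumed list are illustrative:

```python
consumed = []

def numbers():
    for i in range(10):
        consumed.append(i)   # record every value the generator actually produces
        yield i

found = any(n > 2 for n in numbers())
print(found)      # True
print(consumed)   # [0, 1, 2, 3]: any() stopped pulling values at the first match

# Extra argument (reverse=True), so the generator expression needs its own parentheses
top = sorted((n * n for n in range(5)), reverse=True)
print(top)        # [16, 9, 4, 1, 0]
```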

Permalink: https://cs.pynote.net/sf/python/202211171/
