回溯算法

Last Updated: 2024-03-11 13:05:55 Monday

-- TOC --

回溯算法本质上就是DFS,深度优先搜索!

由于可能的搜索空间巨大,一般应用于具体问题时,要根据情况做Pruning。

下面是一段计算组合的示例代码:

class Comb:

    def _comb(self, idx, t):
        """ O(2^N)
        If there are p different values in each position,
        the complexity would be O(p^N).
        """
        if idx == self.length:
            yield t
            return
        yield from self._comb(idx+1, t+(self.lst[idx],))
        yield from self._comb(idx+1, t)

    def _combn(self, m, i, t):
        """ O(N^m) """
        if m == self.m:
            yield t
            return
        while i < self.length:
            yield from self._combn(m+1, i+1, t+(self.lst[i],))
            i += 1

    def comb(self, lst, m=None):
        """ Yield combinations/subsets of lst.
        If m is None, yield all subsets.
        """
        self.lst = lst
        self.length = len(lst)
        self.m = m
        if m is not None:
            assert self.length >= m
            yield from self._combn(0, 0, ())
            return
        yield from self._comb(0, ())


C = Comb()
a = [1,2,3,4,5]
print('combination of', a)
print('All subsets:')
for it in C.comb(a):
    print(it)
print('Fixed size subsets:')
for i in range(6):
    print('Size', i)
    for it in C.comb(a,i):
        print(it)

既可以指定n个元素的组合,也可以不指定,计算所有组合。组合就是subset!

下面是回溯算法的另一种写法,优化内存的使用,本质都一样:

def subsets(nums):

    def __dfs(i, lst):
        yield lst[:]
        for j in range(i, len(nums)):
            lst.append(nums[j])
            yield from __dfs(j+1, lst)
            lst.pop()

    yield from __dfs(0, [])


for it in subsets([1,2,3,4]):
    print(it)

输出:

[]
[1]
[1, 2]
[1, 2, 3]
[1, 2, 3, 4]
[1, 2, 4]
[1, 3]
[1, 3, 4]
[1, 4]
[2]
[2, 3]
[2, 3, 4]
[2, 4]
[3]
[3, 4]
[4]

当需要确定组合大小时,

def subsets(nums):

    def __dfs(i, lst):
        if len(lst) == 2:  # control size
            yield lst[:]
            return
        for j in range(i, len(nums)):
            lst.append(nums[j])
            yield from __dfs(j+1, lst)
            lst.pop()

    yield from __dfs(0, [])


for it in subsets([1,2,3,4]):
    print(it)

用BFS搜索计算组合

下面的内容老旧,存档而已

== 以下存档 ==

组合size确定

假设有一个集合A,内有元素N个,不管是否存在相同元素。问题:找出所有的size=M的组合?(\(M<=N\))

使用循环是最简单直接的方法:

def factorial(n):
    r = 1
    for i in range(2,n+1):
        r *= i
    return r

f = factorial
assert f(4) == 24
assert f(5) == 24*5
assert f(6) == 24*5*6
assert f(7) == 24*5*6*7

a = tuple('123456abcdefghijklmn123456')
N = len(a)

s2 = [(a[i],a[j]) for i in range(N)
                  for j in range(i+1,N)]
assert len(s2) == f(N)/(f(N-2)*f(2))

s3 = [(a[i],a[j],a[k]) for i in range(N)
                       for j in range(i+1,N)
                       for k in range(j+1,N)]
assert len(s3) == f(N)/(f(N-3)*f(3))

s4 = [(a[i],a[j],a[k],a[m])
                for i in range(N)
                for j in range(i+1,N)
                for k in range(j+1,N)
                for m in range(k+1,N)]
assert len(s4) == f(N)/(f(N-4)*f(4))

寻找M个元素的组合,就是M重循环。上面代码的assert语句,有个计算组合数的公式,具体请参考:阶乘,排列和组合

如果要得到没有重复元素的确定个数的组合,可以再做进一步处理,做个去重:

s4b = [t for t in s4 if len(set(t))==4]
print(len(s4))   # 14950
print(len(s4b))  # 13309

以上只是单纯的计算所有组合,并没有带任何过滤条件。

组合size不确定

允许重复元素

假设有一个集合A,内有元素N个,不管是否存在相同元素。问题:找出集合A的所有子集。(既然集合A中存在可能存在重复元素,子集也允许重复元素,但重复次数有上限)

一个思路:先计算单个元素的组合,然后2个元素的组合,......,用M-1个元素的组合,计算M个元素的组合,直到最后的N个元素的组合:

def get_subset_dup(lst: list[int]) -> list[tuple[int]]:
    r = r1 = [(v,) for v in set(lst)]
    for i in range(len(lst)-1):
        r2 = []
        for s in r1:
            v = lst[:]
            for t in s:
                v.remove(t)
            for t in v:
                a = tuple(sorted(s+(t,)))
                if a not in r:
                    r.append(a)
                    r2.append(a)
        if not r2:
            break
        r1 = r2
    return r


print(get_subset_dup([1,2,3]))
print(get_subset_dup([1,2,3,3]))
print(get_subset_dup([1,2,2,2]))

输出:

[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (3, 3), (1, 2, 3), (1, 3, 3), (2, 3, 3), (1, 2, 3, 3)]
[(1,), (2,), (1, 2), (2, 2), (1, 2, 2), (2, 2, 2), (1, 2, 2, 2)]

不允许重复元素

稍微调整一下代码:

def get_subset(lst: list[int]) -> list[tuple[int]]:
    r = r1 = [(v,) for v in set(lst)]
    for i in range(len(lst)-1):
        r2 = []
        for s in r1:
            v = [t for t in lst if t not in s]
            for t in v:
                a = tuple(sorted(s+(t,)))
                if a not in r:
                    r.append(a)
                    r2.append(a)
        if not r2:
            break
        r1 = r2
    return r


print(get_subset([]))
print(get_subset([1]))
print(get_subset([1,2]))
print(get_subset([1,2,3]))
print(get_subset([1,2,3,4]))
print(get_subset([1,2,3,3]))
print(get_subset([1,2,2,2]))

输出:

[]
[(1,)]
[(1,), (2,), (1, 2)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (3,), (4,), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (1, 2)]

使用set

def get_subset(lst: list[int]) -> list[tuple[int]]:
    lstset = set(lst)
    r = [set((v,)) for v in lstset]
    for i in range(len(lstset)-1):
        for s in r[:]:
            v = lstset - s
            r += [s|{t} for t in v if s|{t} not in r]
    return r

总结Python的set操作

示例问题

子集的乘积

复杂度

假设集合A的size为100,它的4元素的组合数为:

\(C_{100}^4=\cfrac{100!}{(100-4)!4!}=3921225\)

如果集合A的size比100还要大,这似乎很容易,100并不是一个很大的数。那么,可以想想,计算出来的组合数,或者计算过程中需要保存的中间结果,数量非常庞大。这会成为一个问题!

动态规划

回溯(backtrack)

回溯,backtrack,是另一个思路。我的理解是:回溯将循环拆解为递归,避免保存中间结果。

示例:

回溯的妙处与特点:

不确定size子集(回溯)

前文使用的算法,在计算不确定size的问题时,空间复杂度太高。而使用回溯算法,几乎不需要任何额外空间。

from pprint import pprint

def comb(lst, c):
    yield c[:]
    for it in lst:
        c.append(it)
        idx = lst.index(it)
        yield from comb(lst[idx+1:], c)
        c.pop()

for i in range(5):
    data = [x for x in range(i)]
    cb = [x for x in comb(data,[])]
    cb.remove([])
    print('data', data)
    pprint(cb)

输出:

data []
[]
data [0]
[[0]]
data [0, 1]
[[0], [0, 1], [1]]
data [0, 1, 2]
[[0], [0, 1], [0, 1, 2], [0, 2], [1], [1, 2], [2]]
data [0, 1, 2, 3]
[[0],
 [0, 1],
 [0, 1, 2],
 [0, 1, 2, 3],
 [0, 1, 3],
 [0, 2],
 [0, 2, 3],
 [0, 3],
 [1],
 [1, 2],
 [1, 2, 3],
 [1, 3],
 [2],
 [2, 3],
 [3]]

能使用递归的问题,都是可以巧妙的划分出子问题的问题。

用回溯计算排列

在计算组合的基础上,稍微修改一下,即可计算排列:

from pprint import pprint

def permute(lst, p):
    if not lst:
        yield p[:]
        return
    for it in lst:
        p.append(it)
        yield from permute([x for x in lst if x!=it],p)
        p.pop()

for i in range(1,5):
    data = [i for i in range(i)]
    prt = [x for x in permute(data,[])]
    print(len(prt))
    pprint(prt)

输出:

1
[[0]]
2
[[0, 1], [1, 0]]
6
[[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]]        
24
[[0, 1, 2, 3],
 [0, 1, 3, 2],
 [0, 2, 1, 3],
 [0, 2, 3, 1],
 [0, 3, 1, 2],
 [0, 3, 2, 1],
 [1, 0, 2, 3],
 [1, 0, 3, 2],
 [1, 2, 0, 3],
 [1, 2, 3, 0],
 [1, 3, 0, 2],
 [1, 3, 2, 0],
 [2, 0, 1, 3],
 [2, 0, 3, 1],
 [2, 1, 0, 3],
 [2, 1, 3, 0],
 [2, 3, 0, 1],
 [2, 3, 1, 0],
 [3, 0, 1, 2],
 [3, 0, 2, 1],
 [3, 1, 0, 2],
 [3, 1, 2, 0],
 [3, 2, 0, 1],
 [3, 2, 1, 0]]

非递归回溯(Stack)

递归采用调用栈保存状态,非递归就自己用一个stack来保存,空间复杂度应该基本一样,时间上的关键,就是push和pop操作与函数调用的性能差异。如果对调用栈的深度没有信心,就只能选择非递归的实现。

非递归回溯计算排列

from pprint import pprint

def permute_loop(lst: list[int]|tuple[int]):
    stack = [(lst,())]
    while stack:
        r, p = stack.pop()
        if not r:
            yield p
        else:
            for i,it in enumerate(r):
                stack.append((r[:i]+r[i+1:],p+(it,)))

for i in range(1,5):
    data = [x for x in range(i)]
    p = [x for x in permute_loop(data)]
    print(data, len(p))
    pprint(p)

输出略。

非递归回溯计算所有组合(子集)

def comb_loop(data):
    stack = [(data[i+1:],(x,)) for i,x in enumerate(data)]
    while stack:
        r, c = stack.pop()
        yield c
        for i,x in enumerate(r):
            stack.append((r[i+1:],c+(x,)))

for i in range(5):
    data = tuple(x for x in range(i))
    p = [x for x in comb_loop(data)]
    print(data, len(p))
    pprint(p)

输出略。

本文链接:https://cs.pynote.net/ag/202307241/

-- EOF --

-- MORE --