Last Updated: 2024-03-11 13:05:55 Monday
-- TOC --
由于可能的搜索空间巨大,一般应用于具体问题时,要根据情况做Pruning。
下面是一段计算组合的示例代码:
class Comb:
def _comb(self, idx, t):
""" O(2^N)
If there are p different values in each position,
the complexity would be O(p^N).
"""
if idx == self.length:
yield t
return
yield from self._comb(idx+1, t+(self.lst[idx],))
yield from self._comb(idx+1, t)
def _combn(self, m, i, t):
""" O(N^m) """
if m == self.m:
yield t
return
while i < self.length:
yield from self._combn(m+1, i+1, t+(self.lst[i],))
i += 1
def comb(self, lst, m=None):
""" Yield combinations/subsets of lst.
If m is None, yield all subsets.
"""
self.lst = lst
self.length = len(lst)
self.m = m
if m is not None:
assert self.length >= m
yield from self._combn(0, 0, ())
return
yield from self._comb(0, ())
C = Comb()
a = [1,2,3,4,5]
print('combination of', a)
print('All subsets:')
for it in C.comb(a):
print(it)
print('Fixed size subsets:')
for i in range(6):
print('Size', i)
for it in C.comb(a,i):
print(it)
既可以指定n个元素的组合,也可以不指定,计算所有组合。组合就是subset!
下面是回溯算法的另一种写法,优化内存的使用,本质都一样:
def subsets(nums):
def __dfs(i, lst):
yield lst[:]
for j in range(i, len(nums)):
lst.append(nums[j])
yield from __dfs(j+1, lst)
lst.pop()
yield from __dfs(0, [])
for it in subsets([1,2,3,4]):
print(it)
输出:
[]
[1]
[1, 2]
[1, 2, 3]
[1, 2, 3, 4]
[1, 2, 4]
[1, 3]
[1, 3, 4]
[1, 4]
[2]
[2, 3]
[2, 3, 4]
[2, 4]
[3]
[3, 4]
[4]
当需要确定组合大小时,
def subsets(nums):
def __dfs(i, lst):
if len(lst) == 2: # control size
yield lst[:]
return
for j in range(i, len(nums)):
lst.append(nums[j])
yield from __dfs(j+1, lst)
lst.pop()
yield from __dfs(0, [])
for it in subsets([1,2,3,4]):
print(it)
下面的内容老旧,存档而已
假设有一个集合A,内有元素N个,不管是否存在相同元素。问题:找出所有的size=M的组合?()
使用循环是最简单直接的方法:
def factorial(n):
r = 1
for i in range(2,n+1):
r *= i
return r
f = factorial
assert f(4) == 24
assert f(5) == 24*5
assert f(6) == 24*5*6
assert f(7) == 24*5*6*7
a = tuple('123456abcdefghijklmn123456')
N = len(a)
s2 = [(a[i],a[j]) for i in range(N)
for j in range(i+1,N)]
assert len(s2) == f(N)/(f(N-2)*f(2))
s3 = [(a[i],a[j],a[k]) for i in range(N)
for j in range(i+1,N)
for k in range(j+1,N)]
assert len(s3) == f(N)/(f(N-3)*f(3))
s4 = [(a[i],a[j],a[k],a[m])
for i in range(N)
for j in range(i+1,N)
for k in range(j+1,N)
for m in range(k+1,N)]
assert len(s4) == f(N)/(f(N-4)*f(4))
寻找M个元素的组合,就是M重循环。上面代码的assert语句,有个计算组合数的公式,具体请参考:阶乘,排列和组合。
如果要得到没有重复元素的确定个数的组合,可以再做进一步处理,做个去重:
s4b = [t for t in s4 if len(set(t))==4]
print(len(s4)) # 14950
print(len(s4b)) # 13309
以上只是单纯的计算所有组合,并没有带任何过滤条件。
假设有一个集合A,内有元素N个,不管是否存在相同元素。问题:找出集合A的所有子集。(既然集合A中存在可能存在重复元素,子集也允许重复元素,但重复次数有上限)
一个思路:先计算单个元素的组合,然后2个元素的组合,......,用M-1个元素的组合,计算M个元素的组合,直到最后的N个元素的组合:
def get_subset_dup(lst: list[int]) -> list[tuple[int]]:
r = r1 = [(v,) for v in set(lst)]
for i in range(len(lst)-1):
r2 = []
for s in r1:
v = lst[:]
for t in s:
v.remove(t)
for t in v:
a = tuple(sorted(s+(t,)))
if a not in r:
r.append(a)
r2.append(a)
if not r2:
break
r1 = r2
return r
print(get_subset_dup([1,2,3]))
print(get_subset_dup([1,2,3,3]))
print(get_subset_dup([1,2,2,2]))
输出:
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (3, 3), (1, 2, 3), (1, 3, 3), (2, 3, 3), (1, 2, 3, 3)]
[(1,), (2,), (1, 2), (2, 2), (1, 2, 2), (2, 2, 2), (1, 2, 2, 2)]
稍微调整一下代码:
def get_subset(lst: list[int]) -> list[tuple[int]]:
r = r1 = [(v,) for v in set(lst)]
for i in range(len(lst)-1):
r2 = []
for s in r1:
v = [t for t in lst if t not in s]
for t in v:
a = tuple(sorted(s+(t,)))
if a not in r:
r.append(a)
r2.append(a)
if not r2:
break
r1 = r2
return r
print(get_subset([]))
print(get_subset([1]))
print(get_subset([1,2]))
print(get_subset([1,2,3]))
print(get_subset([1,2,3,4]))
print(get_subset([1,2,3,3]))
print(get_subset([1,2,2,2]))
输出:
[]
[(1,)]
[(1,), (2,), (1, 2)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (3,), (4,), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (1, 2, 3), (1, 2, 4), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)]
[(1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
[(1,), (2,), (1, 2)]
使用set
def get_subset(lst: list[int]) -> list[tuple[int]]:
lstset = set(lst)
r = [set((v,)) for v in lstset]
for i in range(len(lstset)-1):
for s in r[:]:
v = lstset - s
r += [s|{t} for t in v if s|{t} not in r]
return r
假设集合A的size为100,它的4元素的组合数为:
如果集合A的size比100还要大,这似乎很容易,100并不是一个很大的数。那么,可以想想,计算出来的组合数,或者计算过程中需要保存的中间结果,数量非常庞大。这会成为一个问题!
回溯,backtrack,是另一个思路。我的理解是:回溯将循环拆解为递归,避免保存中间结果。
示例:
回溯的妙处与特点:
前文使用的算法,在计算不确定size的问题时,空间复杂度太高。而使用回溯算法,几乎不需要任何额外空间。
from pprint import pprint
def comb(lst, c):
yield c[:]
for it in lst:
c.append(it)
idx = lst.index(it)
yield from comb(lst[idx+1:], c)
c.pop()
for i in range(5):
data = [x for x in range(i)]
cb = [x for x in comb(data,[])]
cb.remove([])
print('data', data)
pprint(cb)
输出:
data []
[]
data [0]
[[0]]
data [0, 1]
[[0], [0, 1], [1]]
data [0, 1, 2]
[[0], [0, 1], [0, 1, 2], [0, 2], [1], [1, 2], [2]]
data [0, 1, 2, 3]
[[0],
[0, 1],
[0, 1, 2],
[0, 1, 2, 3],
[0, 1, 3],
[0, 2],
[0, 2, 3],
[0, 3],
[1],
[1, 2],
[1, 2, 3],
[1, 3],
[2],
[2, 3],
[3]]
能使用递归的问题,都是可以巧妙的划分出子问题的问题。
在计算组合的基础上,稍微修改一下,即可计算排列:
from pprint import pprint
def permute(lst, p):
if not lst:
yield p[:]
return
for it in lst:
p.append(it)
yield from permute([x for x in lst if x!=it],p)
p.pop()
for i in range(1,5):
data = [i for i in range(i)]
prt = [x for x in permute(data,[])]
print(len(prt))
pprint(prt)
输出:
1
[[0]]
2
[[0, 1], [1, 0]]
6
[[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 0, 1], [2, 1, 0]]
24
[[0, 1, 2, 3],
[0, 1, 3, 2],
[0, 2, 1, 3],
[0, 2, 3, 1],
[0, 3, 1, 2],
[0, 3, 2, 1],
[1, 0, 2, 3],
[1, 0, 3, 2],
[1, 2, 0, 3],
[1, 2, 3, 0],
[1, 3, 0, 2],
[1, 3, 2, 0],
[2, 0, 1, 3],
[2, 0, 3, 1],
[2, 1, 0, 3],
[2, 1, 3, 0],
[2, 3, 0, 1],
[2, 3, 1, 0],
[3, 0, 1, 2],
[3, 0, 2, 1],
[3, 1, 0, 2],
[3, 1, 2, 0],
[3, 2, 0, 1],
[3, 2, 1, 0]]
递归采用调用栈保存状态,非递归就自己用一个stack来保存,空间复杂度应该基本一样,时间上的关键,就是push和pop操作与函数调用的性能差异。如果对调用栈的深度没有信心,就只能选择非递归的实现。
非递归回溯计算排列
from pprint import pprint
def permute_loop(lst: list[int]|tuple[int]):
stack = [(lst,())]
while stack:
r, p = stack.pop()
if not r:
yield p
else:
for i,it in enumerate(r):
stack.append((r[:i]+r[i+1:],p+(it,)))
for i in range(1,5):
data = [x for x in range(i)]
p = [x for x in permute_loop(data)]
print(data, len(p))
pprint(p)
输出略。
非递归回溯计算所有组合(子集)
def comb_loop(data):
stack = [(data[i+1:],(x,)) for i,x in enumerate(data)]
while stack:
r, c = stack.pop()
yield c
for i,x in enumerate(r):
stack.append((r[i+1:],c+(x,)))
for i in range(5):
data = tuple(x for x in range(i))
p = [x for x in comb_loop(data)]
print(data, len(p))
pprint(p)
输出略。
本文链接:https://cs.pynote.net/ag/202307241/
-- EOF --
-- MORE --