Subset Sum,子集和问题

Last Updated: 2024-03-20 08:32:54 Wednesday

-- TOC --

子集和问题是经典的动态规划问题,也是著名的NPC问题。

问题分析

给定一个数据集和一个目标值,问是否存在一个这个数据集的子集,满足此子集中的所有元素的和等于目标值。为了让生活简单一点,我们限定数据集和目标值都是正整数。如果允许负数或零,subset sum问题的求解会更加复杂。

递归

def subset_sum(lst, target):
    if not lst or target<0:
        return False
    if target == 0:
        return True
    a = subset_sum(lst[1:], target-lst[0])
    return a or subset_sum(lst[1:], target)


print(subset_sum([1,2,3,4,5], 6))  # True
print(subset_sum([1,2,3,4,5], 99)) # False

用递归遍历所有可能的组合,这个遍历手法,与回溯很相似。每个元素都有两个状态,在子集内或不在子集内。复杂度为:

$$T(n)=2T(n-1)+1=O(2^n)$$

下面换一种写法,不需要在每次递归的时候创建新的list,这更像回溯了:

class subset_sum:

    def prep(self, lst):
        self.lst = lst
        self.length = len(lst)

    def _sum(self, idx, target):
        if idx==self.length or target<0:
            return False
        if target == 0:
            return True
        a = self._sum(idx+1, target-self.lst[idx])
        return a or self._sum(idx+1, target)

    def sum(self, target):
        return self._sum(0, target)


a = [1,2,3,4,5,6]
ssp = subset_sum()
ssp.prep(a)
print(ssp.sum(6))
ssp.prep(a)
print(ssp.sum(99))

Dynamic Programming

Top-Down

class subset_sum:

    def prep(self, lst):
        self.lst = lst
        self.length = len(lst)
        self.rec = {}

    def _sum(self, idx, target):
        if (idx,target) in self.rec:
            return self.rec[(idx,target)]
        if target == 0:
            return True
        if idx==self.length or target<0:
            return False
        a = self._sum(idx+1, target-self.lst[idx])
        if a:
            self.rec[idx,target] = True
            return True
        b = self._sum(idx+1, target)
        self.rec[idx,target] = b
        return b

    def sum(self, target):
        return self._sum(0, target)


a = [1,2,3,4,5,6,7,8,9]
ssp = subset_sum()
ssp.prep(a)
print(ssp.sum(16))
ssp.prep(a)
print(ssp.sum(99))

复杂度:\(O(n\cdot sum)\)

子问题的数量,不仅仅由n决定,数据集中元素的值,以及target的值,都与子问题的数量有关系。这里是subset sum问题有点特殊的地方!如果\(sum=2^n\)呢!...

对于判定性问题,deterministic problem,有一些技巧可以节省内存开销(当然,具体实现时不一定会使用内存,可能后面是一整个数据库呢):

class subset_sum:

    def prep(self, lst):
        self.lst = lst
        self.length = len(lst)
        self.rec = set()

    def _sum(self, idx, target):
        if (idx,target) in self.rec:
            return True
        if target == 0:
            return True
        if idx==self.length or target<0:
            return False
        a = self._sum(idx+1, target-self.lst[idx])
        if a:
            self.rec.add((idx,target))
            return True
        b = self._sum(idx+1, target)
        if b:
            self.rec.add((idx,target))
        return b

    def sum(self, target):
        return self._sum(0, target)

Bottom-Up

学到一种组合方法:

写出[1,2,3]的所有组合:
[[]], 初始状态
[[],[1]],加入1
[[],[1],[2],[1,2]],加入2
[[],[1],[2],[1,2],[3],[1,3],[2,3],[1,2,3]],加入3
每增加一个元素,组合数翻倍!

按这个思路,实现一个Bottom-Up方案:

def subset_sum_bu(lst, target):
    ss = [0]
    for it in lst:
        t = []
        for s in ss:
            t.append(s+it)
        ss += t
        if target in ss:
            return True
    return False


a = [1,2,3,4]
print(subset_sum_bu(a, 8))
a = [1,2,3,4]
print(subset_sum_bu(a, 119))

时间和空间复杂度,都是\(O(2^n)\)。尽然用非递归的方式,实现了指数级的增长!当然还可以优化这个思路,比如大于target的值就不要加入ss....不过,这不是DP的Bottom-Up算法,它还是暴力的,做了太多加法...

DP的思路,一定要充分利用Subproblems:

def subset_sum_bu(lst, target):
    rec = [[True]+[False]*target for _ in range(len(lst))]

    if lst[0] == target:
        return True
    if lst[0] < target:
        rec[0][lst[0]] = True

    for i,v in enumerate(lst[1:],1):
        for s in range(1,target+1):
            if v > s:
                rec[i][s] = rec[i-1][s]
            else:
                rec[i][s] = rec[i-1][s] or rec[i-1][s-v]

    return rec[len(lst)-1][target]


a = [1,2,3,4]
print(subset_sum_bu(a, 8))
a = [1,2,3,4]
print(subset_sum_bu(a, 88))
a = [1,2,3,4,5,6,7,8,9]
print(subset_sum_bu(a, 34))

很多时候Bottom-Up都有一个优化内存的思路,在一层层往上计算过程中,并不需要保存那些已经没用的子问题的结果。

def subset_sum_bu(lst, target):
    rec = [False] * (target+1)
    rec[0] = True

    if lst[0] == target:
        return True
    if lst[0] < target:
        rec[lst[0]] = True

    nrec = rec[:]
    for i,v in enumerate(lst[1:],1):
        for s in range(1,target+1):
            if v > s:
                nrec[s] = rec[s]
            else:
                nrec[s] = rec[s] or rec[s-v]
        rec = nrec[:]

    return rec[target]

找出组合

如果可以有负数和0

尝试了一下如下代码,可以解决有负数和0的情况:

class subset_sum_all:

    def _sum(self, idx, curr):
        if curr == self.target:
            return True
        if idx == self.length:
            return False
        a = self._sum(idx+1, curr+self.lst[idx])
        return a or self._sum(idx+1, curr)

    def sum(self, lst, target):
        self.lst = lst
        self.length = len(lst)
        self.target = target
        for i,it in enumerate(lst):
            if self._sum(i+1, lst[i]):
                return True
        return False


a = [1,2,3,4]
ssp = subset_sum_all()
print(ssp.sum(a, 8))
a = [1,2,3,4]
ssp = subset_sum_all()
print(ssp.sum(a, 0))
a = [-1,2,-3,4]
ssp = subset_sum_all()
print(ssp.sum(a, -2))
a = [-1,2,-3,4]
ssp = subset_sum_all()
print(ssp.sum(a, 0))

这个才是真宗的\(O(2^n)\)!和不再是递增的,和是一种unbounded的状态....限制为正整数简化了问题,也能够应用一些技巧进行剪枝。

本文链接:https://cs.pynote.net/ag/dp/202403082/

-- EOF --

-- MORE --