Rod-Cutting Problem,切钢条问题

Last Updated: 2024-03-25 14:15:01 Monday

-- TOC --

一根钢条长n米,现在要将其切成几段,只能在整数米的位置切,切完之后,不同长度的钢条价值不等。问题是,如何切能够实现最大价值。可以不切,求最大值以及切法。

问题分析

长为\(n\)米的钢条,共有\(n-1\)个可切的位置,因此共有\(2^{n-1}\)种不同的切法。(想象每个切割点都有两种状态,切或不切,这两种状态彼此间互不影响)

解决此问题的一个前提是,必须要先有不同长度钢条的价格表,假设如下:

Length 1 2 3 4 5  6  7  8  9  10
Price  1 5 8 9 10 17 17 20 24 30

每个切割点都有切或不切两种状态,状态相互独立,价值计算独立,完全适用In or Out思路来遍历足有组合。

另一个常见的遍历组合的思路是Cut and Paste,比如当第1刀切下去之后,钢条被分成两段,就是两个子问题,然后分别求解这两个子问题。其实仔细分析就会发现,这个思路没有错,只是遍历了重复的子问题。左边的子问题是重复的。

Optimal Sub-structure: optimal solutions to a problem incorporate optimal solutions to related subproblems, which you may solve independently.

用Python而不用伪代码,是需要考虑一点点实现上的问题的。比如当n大于10的时候,由于价格表中没有大于10米的价格,这表示超过10米就必须切割,没有人购买超过10米的钢条,比如无法运输,非标等等原因。这个限制实际上大大简化了此问题,大大降低了时间复杂度。而用Python的好处是,在保持高度抽象的情况下,还能够将代码运行起来看效果,而不仅仅是数学证明和分析。

最大值

递归

P = [0,1,5,8,9,10,17,17,20,24,30]
assert len(P) == 10+1
print(P)


def rod_cut(n):
    if n == 0:
        return 0
    maxv = 0
    for i in range(1,min(n+1,10+1)):
        maxv = max(maxv, P[i]+rod_cut(n-i))
    return maxv


for i in range(1,13):
    print(i, rod_cut(i))

print(20, rod_cut(20))
print(24, rod_cut(24))

设\(T(n)\)为rod_cut函数的调用次数,如果对递归前的循环不做min的限制,总能循环到n,即最后总能调用到rod_cut(0),此时,可以写出:

$$T(n)=1+\sum_{i=0}^{n-1}T(i)$$

1表示最开始的那次调用。用Induction的方法,可以证明\(T(n)=2^n\)。

而上面的实现是这样的:

$$T(n)=1+\sum_{i=1}^{10}T(i)$$

这就相当于在对geometric series求和时,总是只加总前10项,最后得到的结果是\(<2^n\),但时间复杂度依然是\(O(2^n)\)。这是不可接受的!

Top-Down

Memoization不是拼写错误!

class rodcut:

    P = [0,1,5,8,9,10,17,17,20,24,30]

    def __init__(self):
        self.clear()

    def cut(self, n):
        if n in self.maxd:
            return self.maxd[n]
        maxv = 0
        for i in range(1,min(n+1,10+1)):
            maxv = max(maxv, rodcut.P[i]+self.cut(n-i))
        self.maxd[n] = maxv
        return maxv

    def clear(self):
        self.maxd = {0:0}


rc = rodcut()
for i in (10,100,900):
    rc.clear()
    print(i, rc.cut(i))

n不能太大,Python默认的递归深度为1000。

到底有多少subproblem?Dynamic Programming确保了每一个subproblem只计算一次,长度为n的钢条,一共就n个subproblem。代码中,在解决每个subproblem时,最多只循环了10次。因此,这个算法的时间复杂度是线性的,\(\Theta(n)\)。

Bottom-Up

Bottom Up版本冲破了递归深度的限制,并且在现实中拥有更好的性能,这来自减少了大量的函数调用的开销。

P = [0,1,5,8,9,10,17,17,20,24,30]


def rod_cut_bu(n):
    prev = [0]
    for i in range(1,n+1):
        maxv = 0
        for j in range(1,min(i+1,11)):
            maxv = max(maxv, P[j]+prev[-j])
        prev.append(maxv)
        prev = prev[-11:]  # save memory
    return maxv


for i in (10,100,900,1000,2000):
    print(i, rod_cut_bu(i))

很明显,Bottom Up版本的时间复杂度为\(\Theta(n)\),也是线性的。这是因为本文有个限制,超过10米的钢条必须要切割,因为价格表中没有超过10米的内容。因此inner loop成了constant。如果没有这个限制,价格表无限大,就与《CLRS》中的计算一样了,是quadratic级别。

Subproblem Graph

厘清子问题间的相互关系,画成图,一目了然,用这个图指导Bottom-Up版本的实现。

本题的子问题关系很清晰,n米长的钢条的解,需要n-1到n-10这10个子问题的解来获得。所以,完全可以从底向上一层层计算子问题的解。

如何切

想办法在计算过程中,记录下来切割点的位置。

Top-Down

class rodcut:

    P = [0,1,5,8,9,10,17,17,20,24,30]

    def __init__(self):
        self.clear()

    def cut(self, n):
        if n in self.maxd:
            return self.maxd[n]
        maxv = 0
        for i in range(1,min(n+1,10+1)):
            t = rodcut.P[i] + self.cut(n-i)
            if maxv < t:
                maxv = t
                self.split[n] = i
        self.maxd[n] = maxv
        return maxv

    def clear(self):
        self.maxd = {0:0}
        self.split = {}

    def show_split(self, n):
        i = n
        while i > 0:
            print(self.split[i], end=' ')
            i -= self.split[i]
        print()


rc = rodcut()
for i in (1,2,3,4,5,6,7,8,9,10,32,239):
    print('n = ', i)
    rc.clear()
    rc.cut(i)
    rc.show_split(i)

输出:

n =  1
1
n =  2
2
n =  3
3
n =  4
2 2
n =  5
2 3
n =  6
6
n =  7
1 6
n =  8
2 6
n =  9
3 6
n =  10
10
n =  32
2 10 10 10
n =  239
3 6 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10

其实,只要先按10米单位切就好了,在10米内的,才需要用到这个算法!这个价格表太短了...

Bottom-Up

P = [0,1,5,8,9,10,17,17,20,24,30]


def rod_cut_bu(n):
    prev = [0]
    split = {}
    for i in range(1,n+1):
        maxv = 0
        for j in range(1,min(i+1,11)):
            t = P[j] + prev[-j]
            if maxv < t:
                maxv = t
                split[i] = j
        prev.append(maxv)
        prev = prev[-11:]
    return maxv, split


for i in (1,2,3,4,5,6,7,8,9,10,32,1234):
    m, s = rod_cut_bu(i)
    print('n =', i, 'max =', m)
    j = i
    while j > 0:
        print(s[j], end=' ')
        j -= s[j]
    print()

本文链接:https://cs.pynote.net/ag/dp/202403071/

-- EOF --

-- MORE --