01背包问题:详细解释为什么重量维度必须从大到小遍历。
01背包
问题描述
题目链接:https://www.lanqiao.cn/problems/1174/learning/?page=1&first_category_id=1&problem_id=1174
特点:每件物品只能拿或者不拿。
解法1
设置状态:dp[i][j]指的是前i件物品重量为j的最大价值。
第i件物品,可以有两种状态:
拿: dp[i][j] = dp[i-1][j-wi] + v[i] 保证这件物品能放下:j≥wi
不拿:dp[i][j] = dp[i-1][j]
取上面两者的最大值。
代码:
n, W = map(int, input().split())
w = [0] * (n+1)
v = [0] * (n+1)
# 输入数据
for i in range(n):w[i+1], v[i+1] = map(int, input().split())
# 创建状态
dp = [[0] * (W+1) for _ in range(n+1)]
for i in range(1, n+1):for j in range(1, W + 1):if j >= w[i]:dp[i][j] = max(dp[i-1][j], dp[i-1][j-w[i]] + v[i])else:dp[i][j] = dp[i-1][j]
print(dp[n][W])
解法2——滚动优化数组(二维 )
对于解法1的状态转移方程:dp[i][j] = max(dp[i-1][j], dp[i-1][j-w[i]] + v[i]),看下面的表格,可以发现影响到当前状态的只有上一维度(i-1),其他维度影响不到,所以可以把dp数组压缩为两个维度,可以省略很多空间。
代码:
n, W = map(int, input().split())
w = [0] * (n+1)
v = [0] * (n+1)
# 输入数据
for i in range(n):w[i+1], v[i+1] = map(int, input().split())
# 创建状态:维度压缩为2
dp = [[0] * (W+1) for _ in range(2)]
for i in range(0, n+1):for j in range(1, W + 1):if j >= w[i]:dp[i%2][j] = max(dp[(i-1)%2][j], dp[(i-1)%2][j-w[i]] + v[i])else:dp[i%2][j] = dp[(i-1)%2][j]
print(dp[n%2][W])
解法3——滚动优化数组(一维)
还是先看上面的表格。
可以看到dp[i][j]由上一维度的dp[i-1][j]和dp[i-1][j-w[i]]决定,下一维度(i+1)就与dp[i-1][j]无关,只与dp[i][j]有关,所以还可以直接用dp[i][j]去替代dp[i-1][j],再省去很多空间。
有一个关键点:就是j这一维度必须要从大到小,以下是个人的理解:
将维度压缩到1,此时的状态转移方程:
dp[j] = max(dp[j], dp[j-w[i]])
为了方便说明,我们假设当前j对应的值为3,i对应的值为3,则对应的状态转移方程变为:dp[3] = max(dp[3], dp[3-w[3]])
虽然i这一维度被省略压缩了,看起来好像似乎没影响,但不是这样的,我们压缩后仍然要保证对应的维度是正确的。
j从大到小去遍历,那么此时j=3这一维度被刷新了,变成了3维对应的数据,j=3维前面的j都没有被影响到。当j减少的时候,假设减少1,那么dp[2](此时i的维度还是3,没变),那这个时候,状态转移方程:dp[2] = max(dp[2], dp[2-w[2]]),由于当j的维度为3的时候,只改了j=3对应的数据(对应i=3的数据),其它没变,可以保证dp[[2]]和dp[[2-w[2]]]这两个数据还是对应第2维的,因此可以成立。
如果从小到大的话,同个道理:
dp[j] = max(dp[j], dp[j-w[i]])
我们假设当前j对应的值为3,i对应的值为3,则对应的状态转移方程变为:
dp[3] = max(dp[3], dp[3-w[3]])
那么此时j=3这一维度被刷新了,变成了3维对应的数据,j=3维前面的j都没有被影响到。当j增加的时候,假设增加1,那么dp[[4]](此时i的维度还是3,没变),那这个时候,状态转移方程:dp[4] = max(dp[4], dp[4-w[2]]),由于当j的维度为3的时候,改了j=3对应的数据(对应i=3的数据),其它没变,如果2-w[2]刚好就等于3,那此时dp[4-w[2]]的数据就对应了第3维度的数据,但是我们要的是第2维的,就不行了。
所以必须保证从大到小去遍历j。
代码:
n, W = map(int, input().split())
w = [0] * (n+1)
v = [0] * (n+1)
# 输入数据
for i in range(n):w[i+1], v[i+1] = map(int, input().split())
# 创建状态:维度压缩为1
dp = [0] * (W+1)
for i in range(0, n+1):for j in range(W, w[i]-1, -1):dp[j] = max(dp[j], dp[j-w[i]] + v[i])
print(dp[W])