GSS3 - Can you answer these queries III 动态DP题解
起因是知乎推给我了这么一篇文章:浅谈动态dp,看了之后感觉还挺好理解的,于是试图做一做上面列出的题,然后写完之后交上去WA了,随便试了几个用例错的离谱,再仔细一看......原来是这篇文章给的DP定义有错误或者说并不能真正拿来做题,难受了
所以这篇文章以GSS3 - Can you answer these queries III一题为例,给出一个正确的动态dp做法(大多数人用线段树直接做的,也挺好维护)
仍然定义\(dp[i]\)表示以\(i\)结尾的最大子段和,定义\(dp2[i]\)为\(...dp[i-1],dp[i]\)之中的最大值,那么有:
\[dp[i]=max(dp[i-1]+a[i],a[i]) \]\[dp2[i]=max(dp[i],dp2[i-1])=max(dp[i-1]+a[i],a[i],dp2[i-1]) \]所以按照动态dp的形式来写就有:
(EXCEL真好用.jpg)
记\(v_i\)和\(m_i\)分别为第\(i\)个数字对应的列向量和转移矩阵,则区间\([l,r]\)的答案为\(dp2[i]\),所以我们要计算\(v_r\),按照
\[v_{r}=m_{r}*m_{r-1}*...*m_{l+1}*v_{l} \]之中\(v_{l}=(a[l],a[l],0)\),然后就可以线段树维护矩阵乘法+单点修改了
特别注意,矩阵乘法不满足交换律,所以线段树里在写的时候一定是右儿子的矩阵乘以左儿子矩阵而不能反过来!
代码:
#include
using namespace std;
#define rep(i,a,n) for(int i=a;i<=n;i++)
#define per(i,a,n) for(int i=n;i>=a;i--)
#define pb push_back
#define SZ(x) ((int)(x).size())
#define fastin ios::sync_with_stdio(0);cin.tie(0);cout.tie(0)
typedef long long ll;
typedef pair pii;
typedef double db;
const ll inf1=-3e18,inf2=-1e18;
// inf2 for (-inf) in formula
struct Matrix{
int n; //n行 m列
static const int sz=5;
ll v[sz][sz];
Matrix(int n):n(n){
rep(i,1,n)rep(j,1,n) v[i][j]=inf1;
};
Matrix operator* (const Matrix B) const {
assert(n==B.n);
Matrix C(n);
for(int i = 1;i <= n;i++)
for(int j = 1;j <= n;j++)
for(int k = 1;k <= n;k++){
C.v[i][j] = max(C.v[i][j],v[i][k]+B.v[k][j]);
}
return C;
}
Matrix quickpow(Matrix A,ll p){
Matrix ans(n);
while(p>0){
if(p&1) ans=A*ans;
A=A*A;
p>>=1;
}
return ans;
}
};
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
struct node{
int l,r;
Matrix val;
node():val(3){};
};
int b[50010];
node a[200010];
inline void upd(int now){
a[now].val=a[rs(now)].val*a[ls(now)].val;
}
void build(int now,int l,int r){
a[now].l=l;a[now].r=r;
if(l==r){
Matrix tmp(3);
tmp.v[1][1]=tmp.v[1][3]=b[l];
tmp.v[2][1]=tmp.v[2][3]=b[l];
tmp.v[3][1]=tmp.v[3][2]=tmp.v[1][2]=inf2;
tmp.v[2][2]=tmp.v[3][3]=0;
a[now].val=tmp;
return;
}
int mid=(l+r)>>1;
build(ls(now),l,mid);
build(rs(now),mid+1,r);
upd(now);
}
Matrix query(int l,int r,int now=1){
assert(r>=l);
int nowl=a[now].l,nowr=a[now].r;
int mid=(nowl+nowr)>>1;
if(nowl==l&&nowr==r){
return a[now].val;
}
if(mid>=r){
return query(l,r,ls(now));
}
else if(mid>1;
if(x<=mid) change(ls(now),x,v);
else change(rs(now),x,v);
upd(now);
}
int n,m,l,r,cmd;
int main(){
scanf("%d",&n);
rep(i,1,n) scanf("%d",&b[i]);
build(1,1,n);
scanf("%d",&m);
rep(i,1,m){
scanf("%d%d%d",&cmd,&l,&r);
if(!cmd){
change(1,l,r);
b[l]=r;
}
else{
if(l==r){
printf("%d\n",b[l]);
}
else{
Matrix tmp=query(l+1,r);
int ans=max( b[l]+max(tmp.v[2][1], tmp.v[2][2]), tmp.v[2][3]);
ans=max(ans,b[l]);
printf("%d\n",ans);
}
}
}
return 0;
}