【图论】AcWing 369. 北大ACM队的远足(DAG 必须边 + 双指针)
这题是在我这两天身体不太舒服的情况下写的,写的比较折磨。。也许这道题同时让我更不适了吧
但是写完提交上去竟然直接过了,有点出乎意料。如果是数据水了或者找到 hack 数据可以发过来
分析
这题我们考虑将问题进行拆解:
- 首先,我们需要找出 DAG 的必须边(桥),DAG 上找必须边还是很简单的:对于一条边 \((u, v)\),记 \(fs[x]\) 为从点 \(S\) 出发(沿正方向)到 \(x\) 的方案数,\(ft[x]\) 为从点 \(T\) 出发(反图)到 \(x\) 的方案数,那么这条边为必须边当且仅当 \(fs[u]\times ft[u] = fs[T]\)。而对于具体实现来说,因为 \(fs,ft\) 可能非常大,因此我们可以对 \(fs,ft\) 的值进行取模(原理同 hash),保险起见,我在实现中取了两种模数。
- 不难发现,所有的桥必然在同一条链(当然这个链可以有多种)上,而基于贪心的思想,我们肯定是想让桥尽可能地接近,这样能够使答案最小化。因此,我们只需要将任意一条 \(S\to T\) 的最短路拿出来就好了,显然,所有 \(S\to T\) 的最短路在以第一条桥开始,最后一条桥结束的部分是完全一样的。
- 最后的问题就是:给你 \(n\)(上述最短路长度)段区间,其中 \(m\) 段具有代价, 让你用 \(2\) 段长度为 \(len\) 的区间覆盖尽可能长的具有代价的段以使答案最小化。
- 对于这个问题,分两种情况讨论:
- 选取的两个区间断点为 \(i\),\(ds[i]\) 表示从左边开始,以第 \(i\) 个区间的右端点为右端点,长度为 \(len\) 的区间能覆盖的最大长度;\(dt[i]\) 表示从右边开始,以第 \(i\) 个区间的左端点为左端点,长度为 \(len\) 的区间能覆盖的最大长度,那么这种情况的覆盖最长长度为 \(\max (ds[i] + dt[i+1])\)。
- 上面的做法是具有断点的时候的最优解,而没有断点的时候,可以看作是 \(2len\) 的区间,进行一次覆盖即可,即求:以第 \(i\) 个区间的右端点为右端点,长度为 \(2len\) 的区间能覆盖的最大长度。
// Problem: 北大ACM队的远足
// Contest: AcWing
// URL: https://www.acwing.com/problem/content/description/371/
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include
using namespace std;
#define debug(x) cerr << #x << ": " << (x) << endl
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define dwn(i,a,b) for(int i=(a);i>=(b);i--)
#define pb push_back
#define all(x) (x).begin(), (x).end()
#define x first
#define y second
using pii = pair;
using ll = long long;
inline void read(int &x){
int s=0; x=1;
char ch=getchar();
while(ch<'0' || ch>'9') {if(ch=='-')x=-1;ch=getchar();}
while(ch>='0' && ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
x*=s;
}
const int N=1e5+5, M=4e5+5;
int n, m, S, T, len;
struct Edge{
int to, w, next;
}e[M];
int h[N], rh[N], tot;
void add(int *h, int u, int v, int w){
e[tot].to=v, e[tot].w=w, e[tot].next=h[u], h[u]=tot++;
}
int din[N], rdin[N];
void init(){
memset(din, 0, sizeof din);
memset(rdin, 0, sizeof rdin);
memset(h, -1, sizeof h);
memset(rh, -1, sizeof rh);
tot=0;
}
int d[N];
const int P1=1e9+7, P2=998244353, INF=0x3f3f3f3f;
pii fs[N], ft[N];
pii pre[N];
bool get_cnt(){
memset(d, 0x3f, sizeof d);
memset(fs, 0, sizeof fs);
memset(pre, 0, sizeof pre);
queue q;
rep(i,1,n) if(!din[i]) q.push(i), d[i]=0;
fs[S]={1, 1};
while(q.size()){
int u=q.front(); q.pop();
for(int i=h[u]; ~i; i=e[i].next){
int go=e[i].to;
if(d[go]>d[u]+e[i].w){
d[go]=d[u]+e[i].w;
pre[go]={u, e[i].w};
}
(fs[go].x+=fs[u].x)%=P1;
(fs[go].y+=fs[u].y)%=P2;
if(--din[go]==0) q.push(go);
}
}
if(!fs[T].x && !fs[T].y) return false;
memset(ft, 0, sizeof ft);
assert(q.empty());
rep(i,1,n) if(!rdin[i]) q.push(i);
ft[T]={1, 1};
while(q.size()){
int u=q.front(); q.pop();
for(int i=rh[u]; ~i; i=e[i].next){
int go=e[i].to;
(ft[go].x+=ft[u].x)%=P1;
(ft[go].y+=ft[u].y)%=P2;
if(--rdin[go]==0) q.push(go);
}
}
return true;
}
using pib = pair;
vector path;
bool chk(int u, int v){
int val1=1LL*fs[u].x*ft[v].x%P1;
int val2=1LL*fs[u].y*ft[v].y%P2;
return pii(val1, val2)==fs[T];
}
void work_pre(int u){
if(!u) return;
auto [p, w]=pre[u];
work_pre(p);
if(p) path.pb({w, chk(p, u)});
}
void get_path(){
path.clear();
work_pre(T);
}
int ds[N], dt[N];
void solve(){
int sz=path.size();
path.insert(begin(path), {0, 0});
int res, sum=0;
vector sp;
for(auto [x, y]: path) sp.pb(x);
rep(i,1,sz){
sp[i]+=sp[i-1];
if(path[i].y) sum+=path[i].x;
}
int j=1, del=0, ma=0;
rep(i,1,sz){
if(path[i].y) del+=path[i].x;
while(sp[i]-sp[j-1]>2*len){
if(path[j].y) del-=path[j].x;
j++;
}
ma=max(ma, del+(path[j-1].y? 2*len-(sp[i]-sp[j-1]): 0));
}
res=sum-ma;
memset(ds, 0, sizeof ds);
memset(dt, 0, sizeof dt);
j=1, del=0;
rep(i,1,sz){
if(path[i].y) del+=path[i].x;
while(sp[i]-sp[j-1]>len){
if(path[j].y) del-=path[j].x;
j++;
}
ds[i]=max(ds[i-1], del+(path[j-1].y? len-(sp[i]-sp[j-1]): 0));
}
j=sz, del=0;
dwn(i,sz,1){
if(path[i].y) del+=path[i].x;
while(sp[j]-sp[i-1]>len){
if(path[j].y) del-=path[j].x;
j--;
}
dt[i]=max(dt[i+1], del+(j+1<=sz && path[j+1].y? len-(sp[j]-sp[i-1]): 0));
}
rep(i,1,sz) res=min(res, sum-(ds[i]+(i+1<=sz? dt[i+1]: 0)));
cout<>cs;
while(cs--){
cin>>n>>m>>S>>T>>len;
S++, T++;
init();
rep(i,1,m){
int u, v, w; read(u), read(v), read(w);
u++, v++;
add(h, u, v, w), add(rh, v, u, w);
din[v]++, rdin[u]++;
}
if(!get_cnt()){
puts("-1");
continue;
}
get_path();
solve();
}
return 0;
}