The 2025 ICPC Asia Chengdu Regional Contest

B

現在有\(n\)個人(\(n\leq 6\)),每個人有一個傷害值\(a_i\)和魔力消耗\(c_i\),在一個回合中,總共可以使用魔力值為\(m\),每一回合的魔力值都會重置為\(m\),如果上一回合使用了第\(i\)個人,那麼這一回合再使用第\(i\)個人的魔力消耗為\(c_i+k\),求\(R\)回合能夠造成的最大傷害。

  • \(R\leq 1e9\)

其實可以很容易的寫出\(f_{i,s}=\max_{\sum_{i\in s}{a_i}+popcount(s\&t)\times k\leq m}{(f_{i-1,t}+A_s)}\),其中\(A_s\)表示選擇人的狀態為\(s\)能夠造成的傷害,即\(A_s=\sum_{i\in s}{a_i}\)。總狀態數為\(R\times 2^n\),每次轉移的時間複雜度為\(O(2^n)\),總時間複雜度為\(O(4^n\times R)\),是會超時的。

考慮狀態\(f_{i,s}\),它只有前一輪的狀態\(t\)決定,因為我們可以預處理\(G[i][j]\),表示上一輪狀態為\(i\),這一輪狀態為\(j\)時增加的傷害。改變上面的式子為\(f_{i,s}=\max{f_{i-1,t}+G[t][s]}\),我們怎樣快速的求出\(f_{i,s}\)呢?可以使用矩陣快速冪。

  • \(dp^{1}[i][j]\)表示初始狀態\(i\),結束狀態為\(j\)經過一回合的最大傷害;
  • \(dp^{2}[i][j]\)表示經過兩個回合的最大傷害,這個我們可以通過枚舉中間狀態\(k\),\(dp^2[i][j]=dp^1[i][k]\times dp^1[k][j]\),也就是枚舉中間經過的這一輪,因為\(dp^1[i][j]\)我們在前面已經求過了;
  • \(dp^4[i][j]\)則可以通過\(dp^2[i][k]\times dp^2[k][j]\)來求得;

所以我們可以用矩陣快速冪來優化\(dp\),使得最終複雜度為\(O(8^n\times \log{R})\)。

#include <bits/stdc++.h>
using namespace std;
#define inf 1e18
#define endl '\n'
#define int long long
typedef  long long ll;
typedef pair<int, int> pii;
int dx[4] = {1, 0, -1, 0}, dy[4] = {0, 1, 0, -1};
const int N = 2e5 + 9, M = 2e5 + 9, mod = 1e9 + 7;
vector<vector<int>> operator*(vector<vector<int>> &A,vector<vector<int>> &B){
	int n=A.size();
	vector<vector<int>> res(n,vector<int>(n));
	for(int k=0;k<n;k++){
		for(int i=0;i<n;i++){
			for(int j=0;j<n;j++){
				res[i][j]=max(res[i][j],A[i][k]+B[k][j]);
			}
		}
	}
	return res;
}
void solve() {
	int n,m,k,R;
	cin >> n >> m >> k >> R;
	vector<int> a(n+1),c(n+1);
	for(int i=1;i<=n;i++){
		cin >> a[i] >> c[i];
	}
	//預處理M[i][j]表示狀態i到j的增量
	vector<vector<int>> M(1<<n,vector<int>(1<<n,0));
	for(int i=0;i<(1<<n);i++){
		for(int j=0;j<(1<<n);j++){
			int suma=0,sumc=0;
			for(int t=0;t<n;t++){
				if(j>>t&1){
					suma+=a[t+1];
					sumc+=c[t+1];
					if(i>>t&1) sumc+=k;
				}
			}
			if(sumc<=m){
				M[i][j]=suma;
			}
		}
	}
	vector<vector<int>> Mr(1<<n,vector<int>(1<<n,0));
	while(R){
		if(R&1) Mr=Mr*M;
		M=M*M;
		R>>=1;
	}
	int ans=0;
	for(int i=0;i<(1<<n);i++){
		for(int j=0;j<(1<<n);j++){
			ans=max(ans,Mr[i][j]);
		}
	}
	cout << ans << endl;
}
/*

*/
signed main() {
	ios::sync_with_stdio(0);
	cin.tie(0), cout.tie(0);
	int t = 1;
	cin >> t;
	while (t--) {
		solve();
	}
	return 0;
}

L

給定一棵有根樹,現在每個節點有兩個屬性\(a_i\)和\(b_i\),對於\(u\)的一棵子樹,我們可以交換子樹上的\(a\),來讓最後\(u\)的子樹上的每個節點的\(a_i=b_i\),特別地,如果\(a_i=0\or b_i=0\),那麼也是可以配對的,也就是説\(0\)是通配符。現在我們要獨立的求出每棵子樹是否都可以通過交換操作讓\(a_i\)和\(b_i\)配對,交換操作是獨立的,也就是不會影響另一棵子樹的求解。

首先因為是可以任意交換\(a\)的,因此\(b\)可以對應\(a\)的任何一個排列,所以相當於\(a\)和\(b\)都可以交換。

考慮求解\(u\)這棵子樹的答案,我們需要對這棵子樹維護一個\(cnt\)數組以及一個\(sum\),用來記錄值為\(cnt_i\)的個數,具體操作就是:

  • 對於\(a_j=i\),\(cnt_i:=cnt_i+1\)。維護\(sum\)時,我們根據\(cnt_i\)的大小來判斷,如果\(cnt_i\geq 0\),那麼\(sum:=sum+1\),否則\(sum:=sum-1\);
  • 對於\(b_j=i\),\(cnt_i:=cnt_i-1\)。維護\(sum\)時,如果\(cnt_i> 0\),那麼\(sum:=sum-1\),否則\(sum:=sum+1\);

特別地,對於\(i=0\),我們只執行\(cnt_0:=cnt_0+1,sum:=sum+1\)。

這樣一棵子樹是否可以完全匹配,可以通過判斷\(sum-cnt_0\leq cnt_0\)來判斷,也就是非0的個數要小於0的個數,那麼也就可以完全匹配。

檢查一棵子樹的時間複雜度為\(O(n)\),如果對每個節點都\(dfs\)一次,時間複雜度變成\(O(n^2)\),不能接收,我們考慮檢查完子樹後,同時把信息上傳,這就可以用到樹上啓發式合併,也是新學的內容,非常的叼,對於需要合併子樹信息的,可以達到時間複雜度\(O(n\log{n})\),適用於只詢問,不修改

樹上啓發式合併的操作流程:

  1. 重鏈剖分,求出重兒子;
  2. \(dfs(u,keep)\),表示操作子樹\(u\),如果\(keep=0\),那麼要撤銷子樹\(u\)的影響;否則保留子樹\(u\)的信息。
  • 先訪問\(u\)的輕兒子;
  • 再訪問\(u\)的重兒子;
  • 加上\(u\)節點自己的貢獻;
  • 再次訪問\(u\)的輕兒子,可以使用一個\(add\)函數,專門來加上貢獻;
  • 求解\(u\)的答案;
  • 根據\(keep\),判斷\(u\)的信息是否保留,如果撤銷,專門用函數\(del\)來撤銷;
  1. 有一個性質是,我們在撤銷輕兒子貢獻的時候,可以直接對\(cnt[a[u]]:=0,cnt[b[u]]:=0,sum:=0\),因為在訪問輕兒子的時候,\(cnt\)數組一定是空的。

這裏可以淺談一下為什麼時間複雜度會是\(O(n\log{n})\),根據上面操作,我們看到多出的部分主要是再次訪問輕兒子以及撤銷輕兒子的貢獻,這兩個是互逆的,所以我們就看再次訪問輕兒子的次數。


我們看8號節點,開始自己是輕兒子,那麼需要撤銷一次,當回溯到4號節點時,因為4是輕兒子,所以8又要被撤銷一次,遇到3被撤銷一次,遇到1被撤銷一次。我們可以看到只有當遇到輕兒子的時候,節點才會被撤銷,但是這個輕兒子和旁邊重兒子合併,相當於節點數量至少是輕兒子子樹大小的兩倍,那麼也就是每次遇到輕兒子,那麼這個輕兒子大小乘以2,那麼最多可以遇到多少個輕兒子呢,也就是\(\log_2{n}\)次,所以8號節點撤銷的次數不會超過\(\log{n}\),所以時間複雜度大概為\(O(n\log{n})\)。

一發就過,爽!!!

void solve() {
	int n;
	cin >> n;
	vector<int> a(n+1),b(n+1);
	for(int i=1;i<=n;i++){
		cin >> a[i];
	}
	for(int i=1;i<=n;i++){
		cin >> b[i];
	}
	vector<vector<int>> tr(n+1);
	for(int i=1;i<n;i++){
		int u,v;cin >> u >> v;
		tr[u].push_back(v);
		tr[v].push_back(u);
	}
	vector<int> cnt(n+1),ans(n+1);
	vector<int> siz(n+1),son(n+1),fa(n+1);
	int sum=0;
	auto go=[&](int u)->void{
		if(a[u]==0){
			cnt[a[u]]++;
			sum++;
		}else if(cnt[a[u]]>=0){
			cnt[a[u]]++;
			sum++;
		}else{
			cnt[a[u]]++;
			sum--;
		}
		if(b[u]==0){
			cnt[b[u]]++;
			sum++;
		}else if(cnt[b[u]]<=0){
			cnt[b[u]]--;
			sum++;
		}else{
			cnt[b[u]]--;
			sum--;
		}	
	};
	auto dfs1=[&](auto&& self,int u,int p)->void{
		siz[u]++;
		fa[u]=p;
		for(int v:tr[u]){
			if(v==p) continue;
			self(self,v,u);
			siz[u]+=siz[v];
			if(siz[v]>siz[son[u]]){
				son[u]=v;
			}
		}
	};
	auto del=[&](auto&&self,int u)->void{
		cnt[a[u]]=0;
		cnt[b[u]]=0;
		sum=0;
		for(int v:tr[u]){
			if(v==fa[u]) continue;
			self(self,v);
		}
	};
	auto add=[&](auto&&self,int u)->void{
		go(u);
		for(int v:tr[u]){
			if(v==fa[u]) continue;
			self(self,v);
		}
	};
	auto dfs2=[&](auto&& self,int u,int keep)->void{
		for(int v:tr[u]){
			if(v==fa[u]||v==son[u]) continue;
			self(self,v,0);
		}
		if(son[u]) self(self,son[u],1);
		//添加節點u的貢獻
		go(u);
		for(int v:tr[u]){
			if(v==fa[u]||v==son[u]) continue;
			add(add,v);
		}
		if(sum-cnt[0]<=cnt[0]) ans[u]=1;
		if(keep==0) del(del,u);
	};
	dfs1(dfs1,1,0);
	dfs2(dfs2,1,0);
	for(int i=1;i<=n;i++){
		cout << ans[i];
	}
	cout << endl;
}

法二:\(dfs\)序+莫隊

基於我們上面的分析,維護\(cnt\)數組和\(sum\)都是\(O(1)\)的。我們跑一遍\(dfs\),給每個節點編一個\(dfn\)序,這樣對於一棵子樹上的問題,在原數組上一定是一塊連續的子段\([L,R]\),那麼總共可以得到\(n\)個子段,我們要維護每個子段的信息,因此可以用離線+莫隊來做,時間複雜度為\(O(n\sqrt{n})\),對於\(\sum{n}\leq 2e5\),是可以接受的。

void solve() {
	int n;
	cin >> n;
	vector<int> a(n+1),b(n+1);
	for(int i=1;i<=n;i++){
		cin >> a[i];
	}
	for(int i=1;i<=n;i++){
		cin >> b[i];
	}
	vector<vector<int>> tr(n+1);
	for(int i=1;i<n;i++){
		int u,v;cin >> u >> v;
		tr[u].push_back(v);
		tr[v].push_back(u);
	}
	vector<int> ans(n+1),cnt(n+1);
	int sum=0;
	vector<int> dfn(n+1);
	int tot=0;
	vector<array<int,3>> seg(n+1);//seg[u]表示u管轄的範圍
	auto dfs1=[&](auto&& self,int u,int p)->void{
		dfn[u]=++tot;
		for(int v:tr[u]){
			if(v==p) continue;
			self(self,v,u);
		}
	};
	auto dfs2=[&](auto&& self,int u,int p)->void{
		seg[u]={dfn[u],dfn[u],u};
		for(int v:tr[u]){
			if(v==p) continue;
			self(self,v,u);
			seg[u][0]=min(seg[u][0],seg[v][0]);
			seg[u][1]=max(seg[u][1],seg[v][1]);
		}
	};
	dfs1(dfs1,1,0);
	dfs2(dfs2,1,0);
	vector<int> mp(n+1);
	for(int i=1;i<=n;i++){
		mp[dfn[i]]=i;//建立反向索引
	}
	int sq=sqrt(n);
	sort(seg.begin()+1,seg.end(),[&](array<int,3> x,array<int,3> y){
		auto[l1,r1,idx1]=x;
		auto[l2,r2,idx2]=y;
		if(l1/sq!=l2/sq) return l1/sq<l2/sq;
		else return r1<r2;
	});
	auto go=[&](int u)->void{
		if(a[u]==0){
			cnt[a[u]]++;
			sum++;
		}else if(cnt[a[u]]>=0){
			cnt[a[u]]++;
			sum++;
		}else{
			cnt[a[u]]++;
			sum--;
		}
		if(b[u]==0){
			cnt[b[u]]++;
			sum++;
		}else if(cnt[b[u]]<=0){
			cnt[b[u]]--;
			sum++;
		}else{
			cnt[b[u]]--;
			sum--;
		}	
	};
	int cl=1,cr=0;
	for(int i=1;i<=n;i++){
		auto[l,r,id]=seg[i];
		while(l<cl) go(mp[--cl]);
		while(r>cr) go(mp[++cr]);
		while(cl<l) go(mp[cl++]);
		while(cr>r) go(mp[cr--]);
		if(sum-cnt[0]<=cnt[0]) ans[id]=1;
	}
	for(int i=1;i<=n;i++){
		cout << ans[i];
	}
	cout << endl;
}

這裏的\(go(u)\)還有點問題。

修改了一下,改成了\(add\)和\(del\):

void solve() {
	int n;
	cin >> n;
	vector<int> a(n+1),b(n+1);
	for(int i=1;i<=n;i++){
		cin >> a[i];
	}
	for(int i=1;i<=n;i++){
		cin >> b[i];
	}
	vector<vector<int>> tr(n+1);
	for(int i=1;i<n;i++){
		int u,v;cin >> u >> v;
		tr[u].push_back(v);
		tr[v].push_back(u);
	}
	vector<int> ans(n+1),cnt(n+1);
	int sum=0;
	vector<int> dfn(n+1);
	int tot=0;
	vector<array<int,3>> seg(n+1);//seg[u]表示u管轄的範圍
	auto dfs1=[&](auto&& self,int u,int p)->void{
		dfn[u]=++tot;
		for(int v:tr[u]){
			if(v==p) continue;
			self(self,v,u);
		}
	};
	auto dfs2=[&](auto&& self,int u,int p)->void{
		seg[u]={dfn[u],dfn[u],u};
		for(int v:tr[u]){
			if(v==p) continue;
			self(self,v,u);
			seg[u][0]=min(seg[u][0],seg[v][0]);
			seg[u][1]=max(seg[u][1],seg[v][1]);
		}
	};
	dfs1(dfs1,1,0);
	dfs2(dfs2,1,0);
	vector<int> mp(n+1);
	for(int i=1;i<=n;i++){
		mp[dfn[i]]=i;//建立反向索引
	}
	int sq=sqrt(n);
	sort(seg.begin()+1,seg.end(),[&](array<int,3> x,array<int,3> y){
		auto[l1,r1,idx1]=x;
		auto[l2,r2,idx2]=y;
		if(l1/sq!=l2/sq) return l1/sq<l2/sq;
		else return r1<r2;
	});
	auto add=[&](int u)->void{
		if(a[u]==0){
			cnt[a[u]]++;
			sum++;
		}else if(cnt[a[u]]>=0){
			cnt[a[u]]++;
			sum++;
		}else{
			cnt[a[u]]++;
			sum--;
		}
		if(b[u]==0){
			cnt[b[u]]++;
			sum++;
		}else if(cnt[b[u]]<=0){
			cnt[b[u]]--;
			sum++;
		}else{
			cnt[b[u]]--;
			sum--;
		}
	};
	auto del=[&](int u)->void{
		if(a[u]==0){
			cnt[a[u]]--;
			sum--;
		}else if(cnt[a[u]]>0){
			cnt[a[u]]--;
			sum--;
		}else{
			cnt[a[u]]--;
			sum++;
		}
		if(b[u]==0){
			cnt[b[u]]--;
			sum--;
		}else if(cnt[b[u]]<0){
			cnt[b[u]]++;
			sum--;
		}else{
			cnt[b[u]]++;
			sum++;
		}
	};
	int cl=1,cr=0;
	for(int i=1;i<=n;i++){
		auto[l,r,id]=seg[i];
		while(l<cl) add(mp[--cl]);
		while(r>cr) add(mp[++cr]);
		while(cl<l) del(mp[cl++]);
		while(cr>r) del(mp[cr--]);
		if(sum-cnt[0]<=cnt[0]) ans[id]=1;
	}
	for(int i=1;i<=n;i++){
		cout << ans[i];
	}
	cout << endl;
}