每种颜色分开 dp 计算贡献
题意
给定长度为 n 的序列 a ,求满足以下条件的序列 a 的个数。
- ∀i∈[1,n] ,有 0≤ai≤m 。
- ∀j∈[1,q] ,有 max(alj,alj+1,...,arj)=xj 。
其中 1≤n≤2⋅105,1≤m<998244353,1≤q≤2⋅105,1≤li,ri≤n,1≤xi≤m 。
思路
构造序列 b ,初始时 bi=m ,对于每个操作 j ,令 bi=min(bi,xj) (lj≤i≤rj) 。
操作完之后条件即转变为:
- ∀i∈[1,n] ,有 0≤ai≤bi 。
- ∀j∈[1,q] ,有 max(alj,alj+1,...,arj)=xj 。
b 序列可以 O(n⋅logn) 求出。
现在考虑怎么满足每个条件 (lj,rj,xj) ,注意到对于一个条件 (lj,rj,xj) ,只有满足 bi=xj (lj≤i≤rj) 的位置 i 能达到条件。
因为若有位置 i 使得 bi>xj ,则 i 一定不在 [lj,rj] 中,若有位置 i 使得 bi<xj ,则 bi 取不到 xj 。
所以可以对于每一个 x ,将所有 xj=x 的条件与 bi=x 的位置独立出来计算贡献,最后再乘起来。
那么对于单个 x ,现在问题转变为:
- ∀i∈[1,n] ,有 0≤ai≤m 。
- ∀j∈[1,q] ,有 max(alj,alj+1,...,arj)=m 。
- ∪j=1q[lj,rj]=[1,n]
那么可以考虑一个 DP 解决。
dpi 表示填充了前 i 个位置,满足了所有 rj≤i 的条件,且第 i 个位置值为 m 的序列的个数,在 i 位置统一处理 rj=i 的区间,转移显然:
{dpi=∑j=0i−1dpj⋅mi−j−1 (对 ∀k 满足 i=rk 有 lk≤j)dp0=1
预处理 b 序列复杂度为 O(n⋅logn) ,DP 部分总复杂度为 O(n) 。\
类似套路的题 CF1327F,P4229。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
|
#include<bits/stdc++.h> using namespace std;
using ll=long long; using pii=pair<int,int>;
constexpr int mod=998244353; constexpr int N=2e5+5;
int main(){ ios::sync_with_stdio(0); cin.tie(0);cout.tie(0);
int n,m,q; cin>>n>>m>>q;
vector<int> a(n,m+1); map<int,vector<pii>> qry,mp; for(int i=0,l,r,x;i<q;i++){ cin>>l>>r>>x; l--,r--; qry[x+1].push_back({l,r}); mp[l].push_back({x,1}); mp[r+1].push_back({x,-1}); }
multiset<int> st; map<int,vector<int>> pos; int cnt=0; for(int i=0;i<n;i++){ for(auto [x,flag]:mp[i]){ if(flag==1) st.insert(x); else st.erase(st.find(x)); } if(st.size()) a[i]=(*st.begin())+1; else cnt++; pos[a[i]].push_back(i); }
auto power=[&](int a,int b=mod-2){ int res=1; while(b){ if(b&1) res=1ll*res*a%mod; a=1ll*a*a%mod; b>>=1; } return res; };
auto calc=[&](vector<int> pos,vector<pii> qry,int val){ for(auto &[l,r]:qry){ l=lower_bound(pos.begin(),pos.end(),l)-pos.begin(); r=upper_bound(pos.begin(),pos.end(),r)-pos.begin()-1; if(l>r) return 0; } for(int i=0;i<pos.size();i++) pos[i]=i; vector<int> dp(pos.size(),0),maxpos(pos.size(),-1); for(auto [l,r]:qry) maxpos[r]=max(maxpos[r],l); int l=-1; int sum=1; for(int i=0;i<pos.size();i++){ dp[i]=sum; while(l<maxpos[i]){ if(l==-1) sum=(sum+mod-power(val-1,i-l-1))%mod; else sum=(sum+mod-1ll*power(val-1,i-l-1)*dp[l]%mod)%mod; l++; } sum=1ll*sum*(val-1)%mod; sum=(sum+dp[i])%mod; } return sum; };
int ans=1; for(auto [x,vec]:qry) ans=1ll*ans*calc(pos[x],qry[x],x)%mod; ans=1ll*ans*power(m+1,cnt)%mod;
cout<<ans;
return true&&false; }
|