2022 hdu 多校 8 1002

线段树 + 单调栈优化 dp


写了 40mins 写挂了 pushdown ,该加训了。

题意

定义一个序列是“好的”当且仅当它满足序列中的最大值所在下标为奇数,最小值所在下标为偶数。
定义一种序列的划分是“吊的”当且仅当它恰好将序列划分成了若干个子串,子串原序拼接能得到原序列且每个子串都是“好的”。
序列和划分后的子串下标都从 11 开始。
求长为 nn ,各个元素互不相同的序列 aa 的“吊的”划分方案数。
其中 1n3105,1ai1091 \leq n \leq 3 \cdot 10^5, 1 \leq a_i \leq 10^9

思路

最大最小值套路地想到单调栈维护 maxposmaxposminposminpos
考虑倒着 dpdpdpidp_i表示序列 [i,n][i, n] 一段的“吊的”划分方案数,有:

{dp[i]=j=indp[j+1][maxpos[i,j] mod 2=i mod 2 && minpos[i,j] mod 2i mod 2]dp[n+1]=1\begin{cases} dp[i] = \sum\limits_{j = i}^{n} dp[j + 1] \cdot [maxpos[i, j]\ mod\ 2 = i\ mod\ 2\ \And\And \ minpos[i, j]\ mod\ 2 \not = i\ mod\ 2]\\ dp[n + 1] = 1\\ \end{cases}

倒着 dpdp 同时维护两个单调栈 ma,mima, mi 记录 ii 往后第一个比它大 / 小的数的位置,可以得到固定 ii 为左端点时,右端点 jj 选取不同位置形成的区间中的 maxposmaxposminposminpos 的奇偶性。
开数组 sum[0/1][0/1]sum[0/1][0/1] 记录当前所有右端点选取情况中 maxposmaxpos 为偶 / 奇, minposminpos 为偶 / 奇时可转移的 dpdp 值的和,转移时使 dp[i]+=sum[i&1][(i&1)1]dp[i] += sum[i \And 1][(i \And 1) \bigoplus 1] 即可。
同时开两个线段树 tr[0/1]tr[0/1] 分别记录区间内 maxpos/minposmaxpos / minpos 为偶 / 奇的位置对应的 dpdp 值的和,弹出 / 加入单调栈时维护对应 sum[0/1][0/1]sum[0/1][0/1] 的值即可。
总复杂度为 O(nlogn)O(n \cdot logn)

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/*
Author: Cupids_Bow
Time: 2022-08-11 15:40:59
*/
#include<bits/stdc++.h>
using namespace std;

using ll=long long;
using pll=pair<ll,ll>;

constexpr ll mod=998244353;
constexpr int N=3e5+5;
int _=1;

vector<int> ma,mi;
int n;
int a[N];
ll sum[2][2];
ll dp[N];
struct tree{
int l[N<<2];
int r[N<<2];
ll lz[N<<2];
ll sum[N<<2][2];
void pushup(int x){
for(int i=0;i<=1;i++) sum[x][i]=(sum[x<<1][i]+sum[x<<1|1][i])%mod;
}
void build(int x,int L,int R){
l[x]=L;
r[x]=R;
lz[x]=-1;
for(int i=0;i<=1;i++) sum[x][i]=0;
if(L==R) return;
int mid=L+R>>1;
build(x<<1,L,mid);
build(x<<1|1,mid+1,R);
return;
}
void pushdown(int x){
if(lz[x]!=-1){
int col=lz[x];
sum[x<<1][col]=(dp[l[x<<1]+1]+mod-dp[r[x<<1]+2])%mod;
sum[x<<1][col^1]=0;
lz[x<<1]=col;
sum[x<<1|1][col]=(dp[l[x<<1|1]+1]+mod-dp[r[x<<1|1]+2])%mod;
sum[x<<1|1][col^1]=0;
lz[x<<1|1]=col;
lz[x]=-1;
}
return;
}
void change(int x,int L,int R,int col){
if(l[x]>=L&&r[x]<=R){
sum[x][col]=(dp[l[x]+1]+mod-dp[r[x]+2])%mod;
sum[x][col^1]=0;
lz[x]=col;
return;
}
pushdown(x);
if(r[x<<1]>=L) change(x<<1,L,R,col);
if(l[x<<1|1]<=R) change(x<<1|1,L,R,col);
pushup(x);
return;
}
pll search(int x,int L,int R){
if(l[x]>=L&&r[x]<=R) return {sum[x][0],sum[x][1]};
pushdown(x);
pll res={0,0};
if(r[x<<1]>=L){
pll node=search(x<<1,L,R);
res.first=(res.first+node.first)%mod;
res.second=(res.second+node.second)%mod;
}
if(l[x<<1|1]<=R){
pll node=search(x<<1|1,L,R);
res.first=(res.first+node.first)%mod;
res.second=(res.second+node.second)%mod;
}
return res;
}
}tr[2];

void work(){
ma.clear();
mi.clear();
memset(sum,0,sizeof(sum));
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=0;i<=n+2;i++) dp[i]=0;
dp[n+1]=1;
ma.push_back(n+1);
mi.push_back(n+1);
tr[0].build(1,1,n+1);
tr[1].build(1,1,n+1);
for(int i=n;i>=1;i--){
while(ma.back()!=n+1&&a[ma.back()]<a[i]){
int pos=ma.back();
int col=(pos&1);
ma.pop_back();
pll res=tr[1].search(1,pos,ma.back()-1);
sum[col][0]=(sum[col][0]+mod-res.first)%mod;
sum[col][1]=(sum[col][1]+mod-res.second)%mod;
}
pll res=tr[1].search(1,i,ma.back()-1);
int col=(i&1);
sum[col][0]=(sum[col][0]+res.first)%mod;
sum[col][1]=(sum[col][1]+res.second)%mod;
tr[0].change(1,i,ma.back()-1,(i&1));
ma.push_back(i);
while(mi.back()!=n+1&&a[mi.back()]>a[i]){
int pos=mi.back();
int col=(pos&1);
mi.pop_back();
pll res=tr[0].search(1,pos,mi.back()-1);
sum[0][col]=(sum[0][col]+mod-res.first)%mod;
sum[1][col]=(sum[1][col]+mod-res.second)%mod;
}
res=tr[0].search(1,i,mi.back()-1);
col=(i&1);
sum[0][col]=(sum[0][col]+res.first)%mod;
sum[1][col]=(sum[1][col]+res.second)%mod;
tr[1].change(1,i,mi.back()-1,(i&1));
mi.push_back(i);
dp[i]=sum[i&1][(i&1)^1];
dp[i]=(dp[i]+dp[i+1])%mod;
tr[0].change(1,i,i,(i&1));
tr[1].change(1,i,i,(i&1));
}
cout<<(dp[1]+mod-dp[2])%mod<<"\n";
}

int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);

cin>>_;
while(_--){
work();
}

return true&&false;
}