题目大意
给出一个长度为nnn的序列AAA,再给出一个整数xxx,如果一个子序列满足以下的条件,则它是一个符合条件的子序列:
- 序列中的任意两个数的异或结果都大于等于xxx
求符合条件的子序列的个数,模998244353998244353998244353。
两个子序列不同,当且仅当它们取自于原序列中的位置中有至少一个位置不同。
数据范围
1≤n≤3×105,1≤Ai≤260,1≤x≤2601\leq n\leq 3\times 10^5,1\leq A_i\leq 2^{60},1\leq x\leq 2^{60}1≤n≤3×105,1≤Ai≤260,1≤x≤260
题解
令xxx的最高位为第ddd位。如果有两个数,它们在第ddd位以上的部分不同,那么显然它们异或之后一定有一位比第ddd位高且为111。也就是说,如果有两个数,它们在第ddd位以上的部分不同,那么它们不会有冲突。
我们把序列AAA中具有相同的第ddd位以上的部分的数放在一起,那么整个序列就被分成若干个同前缀序列。我在其中一个同前缀序列中选符合条件的数,并不影响我在另一个同前缀序列中选的数。所以我们可以先把每个同前缀序列中符合条件的子序列个数求出,然后求积,即可得出答案。
那么怎么求一个同前缀序列中符合条件的子序列个数呢?
因为分在一个同前缀序列中的数的第ddd位以上的部分相同,所以任意两个数的异或和的最高位一定不高于第ddd位。假如当前的子序列有两个数,那么这两个数的第ddd位一定满足一个是000。一个是111,否则异或和的第ddd位为000,也就小于xxx。如果我们要再加一个数,那么无论这个数的第ddd位是000还是111,它与先前的数异或之后一定会有一种情况使得第ddd位为000。所以,可以证明,在一个同前缀序列中符合条件的子序列的数的个数最多只有两个。
数的个数为000或111的子序列的个数很好求,分别为111和该同前缀序列的大小。若子序列的数的个数为222,那么对于每一个数,我们需要找到另一个数,使得这两个数的异或和大于等于xxx。
我们可以用字典树来维护。假设当前枚举的数为vvv,那么在字典树中:
- 如果xxx在当前这一位为111,则这两个数必须不同,我们就往vvv的这一位的另一个数的一边查找
- 如果xxx在当前这一位为000,则这两个数可以不同,我们就把vvv的这一位的另一个数的一边的数的和加到当前的答案上,在往vvv的这一位的数的一边查找
求出每一个同前缀序列的答案求出后,将每个同前缀序列的答案求积,即为最终的答案。
不要忘了最后减去序列为空的情况。
时间复杂度为O(60×n)O(60\times n)O(60×n)。
code
#include<bits/stdc++.h>
using namespace std;
int n,v1=0,tot,siz[15000005],ch[15000005][2];
long long x,tx,s=1,ans=1,now,mi[65],a[300005],v[300005];
long long mod=998244353;
vector<long long>w[300005];
void pt(long long v){int q=1,vq;for(int i=60;i>=0;i--){vq=(v>>i)&1;if(!ch[q][vq]) ch[q][vq]=++tot;q=ch[q][vq];++siz[q];}
}
void find(long long v){int q=1,vq;for(int i=60;i>=0;i--){vq=(v>>i)&1;if((x>>i)&1){if(!ch[q][vq^1]) return;q=ch[q][vq^1];}else{if(ch[q][vq^1]) now=(now+siz[ch[q][vq^1]])%mod;if(!ch[q][vq]) return;q=ch[q][vq];}}now=(now+siz[q])%mod;
}
void cl(int x){if(ch[x][0]) cl(ch[x][0]);if(ch[x][1]) cl(ch[x][1]);ch[x][0]=ch[x][1]=0;siz[x]=0;
}
int main()
{scanf("%d%lld",&n,&x);mi[0]=1;for(int i=1;i<=60;i++) mi[i]=mi[i-1]*2;tx=x;while(x){x>>=1;s<<=1;}s=mi[60]-s;x=tx;for(int i=1;i<=n;i++){scanf("%lld",&a[i]);}sort(a+1,a+n+1);for(int i=1;i<=n;i++){if(v1==0||(a[i]&s)!=v[v1]){v[++v1]=(a[i]&s);}w[v1].push_back((a[i]|s)^s);}for(int i=1;i<=v1;i++){int l=w[i].size();now=l+1;tot=1;for(int j=0;j<l;j++){find(w[i][j]);pt(w[i][j]);}ans=ans*now%mod;cl(1);}ans=(ans+mod-1)%mod;printf("%lld",ans);return 0;
}