LOJ6570 毛毛虫计数
tags:生成函数,多项式
题意
hsezoi 巨佬 olinr 喜欢 van 毛毛虫,他定义毛毛虫是一棵树,满足树上存在一条树链,使得树上所有点到这条树链的距离最多为 \(1\)。给定 \(n\) \((n\le10^5)\) 。现在请你求出 \(n\) 个点、有标号的毛毛虫的数量。答案对 \(998244353\) 取模。
题解
构造生成函数
对于毛毛虫直径中间的一个节点,大小为 i 总共有 i 种放法,指数型生成函数是
\[ A(x)=\sum_{i=1}^\infty\frac{ix^i}{i!} \]
对于与直径两端点相连的一个节点,强制至少挂一个节点,指数型生成函数是
\[ B(x)=\sum_{i=2}^\infty\frac{ix^i}{i!} \]
然后结果就是 \((A(x)^0+A(x)^1+\cdots)B(x)^2=\frac {B(x)^2}{1-A(x)}\)
输出:
\[ \frac{n!}2[x^n]\frac {B(x)^2}{1-A(x)} \]
注意要特判 n = 2
顺便放下我的多项式板子,需要的时候可以拉
#include<cstdio>
#include<vector>
//#define debug(...) fprintf(stderr,__VA_ARGS__)
#define debug(...) ((void)0)
typedef std::vector<int> poly;
const int P=998244353;
int fpow(int a,int b){int res=1;for(;b;b>>=1,a=1ll*a*a%P)if(b&1)res=1ll*res*a%P;return res;}
void pt(const poly&a){for(int i=0;i<(int)a.size();++i)debug("%d ",a[i]);debug("\n");}
int getlim(int n){int x=1;while(x<=n)x<<=1;return x;}
void ntt(poly&a,int g,int lim){a.resize(lim);for(int i=0,j=0;i<lim;++i){if(i<j)std::swap(a[i],a[j]);for(int k=lim>>1;(j^=k)<k;k>>=1);}poly w(lim>>1);w[0]=1;for(int i=1;i<lim;i<<=1){for(int j=1,wn=fpow(g,(P-1)/(i<<1));j<i;++j)w[j]=1ll*w[j-1]*wn%P;for(int j=0;j<lim;j+=i<<1)for(int k=0;k<i;++k){int x=a[j+k],y=1ll*a[i+j+k]*w[k]%P;a[j+k]=(x+y)%P,a[i+j+k]=(x-y+P)%P;}}if(g==332748118)for(int i=0,I=fpow(lim,P-2);i<(int)a.size();++i)a[i]=1ll*a[i]*I%P;
}
poly pmul(poly a,poly b){int need=(int)a.size()+b.size()-1,lim=getlim(need);ntt(a,3,lim),ntt(b,3,lim);for(int i=0;i<lim;++i)a[i]=1ll*a[i]*b[i]%P;ntt(a,332748118,lim);return a.resize(need),a;
}
poly padd(poly a,poly b){if(a.size()<b.size()){for(int i=0;i<(int)a.size();++i)(b[i]+=a[i])%=P;return b;}else{for(int i=0;i<(int)b.size();++i)(a[i]+=b[i])%=P; return a;}
}
poly pinv(const poly&a,int n=-1){if(n==-1)n=a.size();if(n==1)return poly(1,fpow(a[0],P-2));poly b=pinv(a,(n+1)>>1),tmp=poly(a.begin(),a.begin()+n);int lim=getlim(n*2-2);ntt(b,3,lim),ntt(tmp,3,lim);for(int i=0;i<lim;++i)b[i]=(2-1ll*b[i]*tmp[i]%P+P)%P*b[i]%P;ntt(b,332748118,lim);return b.resize(n),b;
}
poly pdao(const poly&a){poly b((int)a.size()-1);for(int i=1;i<(int)a.size();++i)b[i-1]=1ll*a[i]*i%P;return b;
}
poly pji(const poly&a){poly b((int)a.size()+1);for(int i=0;i<(int)a.size();++i)b[i+1]=1ll*a[i]*fpow(i+1,P-2)%P;return b;
}
poly pln(const poly&a){poly b(pmul(pdao(a),pinv(a)));b.resize((int)a.size()-1);return pji(b);
}
poly pexp(const poly&a,int n=-1){if(n==-1)n=a.size();if(n==1)return poly(1,1);poly b=pexp(a,(n+1)>>1),c(b);c.resize(n),c=pln(c),--c[0];for(int i=0;i<n;++i)c[i]=(a[i]-c[i]+P)%P;poly d(pmul(b,c));return d.resize(n),d;
}
const int N=100005;
int n,fac[N],inv[N];
int main(){fac[0]=fac[1]=inv[0]=inv[1]=1;for(int i=2;i<N;++i)fac[i]=1ll*fac[i-1]*i%P,inv[i]=1ll*(P-P/i)*inv[P%i]%P;for(int i=2;i<N;++i)inv[i]=1ll*inv[i]*inv[i-1]%P;scanf("%d",&n);if(n<=2)return puts("1"),0;poly A(n+1),B(n+1);A[0]=1;for(int i=1;i<=n;++i)A[i]=(P-1ll*i*inv[i]%P)%P;for(int i=2;i<=n;++i)B[i]=1ll*i*inv[i]%P;A=pinv(A),B=pmul(B,B),B.resize(n+1),A=pmul(A,B),A.resize(n+1);printf("%lld\n",(n+1ll*A[n]*fac[n]%P*((P+1)>>1)%P)%P);return 0;
}