你如果能确定一个问题答案一定是一个多项式形式,那么你可以先暴力求出来几个点的解,带入,把这个多项式的系数求出来,接下来给出自变量的话,你直接往这个式子里带入就能得到答案了。
具体的原理就是oiwiki上的这个过程
这里需要注意的是,对于一个最高次为k的多项式,至少需要k+1个不同的点才能确定全部系数。
求系数的过程暴力是 O ( n 2 ) O(n^2) O(n2)的,这要求我们多项式次数不能太大。不过对于连续的数据点( ( l , f ( l ) ) , ( l + 1 , f ( l + 1 ) ) . . . (l,f(l)),(l+1,f(l+1))... (l,f(l)),(l+1,f(l+1))...),可以 O ( n ) O(n) O(n)求系数
先来一个模板题,没啥好说的,把给出的点传进去,返回多项式系数,然后把自变量 x x x带入多项式就行了,注意是 x k x^k xk和 f [ k ] f[k] f[k]相乘
std::vector<int> lagrange_interpolation(const std::vector<int> &x,
const std::vector<int> &y,int MOD) {
const int n = x.size();
std::vector<int> M(n + 1), xx(n), f(n);
M[0] = 1;
// 求出 M(x) = prod_(i=0..n-1)(x - x_i)
for (int i = 0; i < n; ++i) {
for (int j = i; j >= 0; --j) {
M[j + 1] = (M[j] + M[j + 1]) % MOD;
M[j] = (LL)M[j] * (MOD - x[i]) % MOD;
}
}
// 求出 xx_i = M'(x_i) = (M / (x - x_i)) mod (x - x_i) 一定非零
for (int i = n - 1; i >= 0; --i) {
for (int j = 0; j < n; ++j) {
xx[j] = ((LL)xx[j] * x[j] + (LL)M[i + 1] * (i + 1)) % MOD;
}
}
// 组合出 f(x) = sum_(i=0..n-1)(y_i / M'(x_i))(M / (x - x_i))
for (int i = 0; i < n; ++i) {
LL t = (LL)y[i] * inv(xx[i],MOD) % MOD, k = M[n];
for (int j = n - 1; j >= 0; --j) {
f[j] = (f[j] + k * t) % MOD;
k = (M[j] + k * x[i]) % MOD;
}
}
return f;
}
void solve(){
cin>>n>>k;
vi x(n),y(n);
rep(i,0,n-1)cin>>x[i]>>y[i];
vi f=lagrange_interpolation(x,y,M2);
int ans=0;
rep(i,0,n-1){
ans+=f[i]*power(k,i,M2)%M2;
ans%=M2;
}
cout<<ans;
}
接下来是一个稍微复杂一点的题。 s ( k , n ) s(k,n) s(k,n)表示一个长度不超过 n n n的数组,元素都在 [ 1 , k ] [1,k] [1,k]之间,所有元素乘起来是 k k k,的方案数。给定 k k k求 n ∈ [ 1 , k ] n∈[1,k] n∈[1,k]的答案
首先固定数组长度为 n n n的话,这个问题实际上就是个球盒模型,对于 k k k的每个质因子 f a c fac fac,如果指数为 p p p的话,我们要做的就是把这 p p p个 f a c fac fac安排到数组中的 n n n个位置上,求方案数。
这就是一个球相同,盒子不同,可以空盒的球盒问题, C ( n + r − 1 , r − 1 ) , r 为盒数 , n 为球数 C(n+r-1,r-1),r为盒数,n为球数 C(n+r−1,r−1),r为盒数,n为球数。然后不同质因子之间互不干扰,根据乘法原理方案应该乘起来
对于每一个数组长度,答案显然是个·多项式, s ( k , n ) s(k,n) s(k,n)就是把每个长度的多项式解加起来,还是个多项式。并且注意到 C ( n + r − 1 , r − 1 ) = C ( n + r − 1 , n ) , n 为球数 C(n+r-1,r-1)=C(n+r-1,n),n为球数 C(n+r−1,r−1)=C(n+r−1,n),n为球数,那么 n n n也就是 k k k的质因子的最大指数,那么 n n n最大也只是 l o g ( k ) log(k) log(k)的,不超过 17 17 17,所以我们可以知道这个多项式不会超过18个系数,那么我们传入 18 18 18个不同数据点就能得到这个多项式了。
这18个数据点就暴力求就完了,反正复杂度很低,可以考虑把 s ( k , 1 ) , s ( k , 2 ) . . . s ( k , 18 ) s(k,1),s(k,2)...s(k,18) s(k,1),s(k,2)...s(k,18)求出来。这里注意 s ( k , n ) s(k,n) s(k,n)也是包含 s ( k , 1 ) . . . s ( k , n − 1 ) s(k,1)...s(k,n-1) s(k,1)...s(k,n−1)的,可以用一个前缀和累加
求出来之后这就是一个关于 k k k的 17 17 17次多项式,把题目给的 k k k往里带就行了
const int mod = 998244353;
typedef Modint<mod> mint;
template <const int mod> struct comb {
vector<int> f, g, v;
comb(int n) : f(n + 1), g(n + 1), v(n + 1) {
f[0] = g[0] = v[0] = 1;
v[1] = 1;
for (int i = 2; i <= n; i++) {
v[i] = (1LL * (mod - mod / i) * v[mod % i]) % mod;
}
for (int i = 1; i <= n; i++) {
f[i] = (1LL * i * f[i - 1]) % mod;
g[i] = (1LL * g[i - 1] * v[i]) % mod;
}
}
ll operator()(int n, int m) {
if (m > n || m < 0 || n < 0) return 0;
int ans = (1LL * f[n] * g[m]) % mod;
ans = (1LL * ans * g[n - m]) % mod;
return ans;
}
};
comb<mod> C(1000000);
std::vector<int> lagrange_interpolation(const std::vector<int> &x,
const std::vector<int> &y,int MOD) {
const int n = x.size();
std::vector<int> M(n + 1), xx(n), f(n);
M[0] = 1;
// 求出 M(x) = prod_(i=0..n-1)(x - x_i)
for (int i = 0; i < n; ++i) {
for (int j = i; j >= 0; --j) {
M[j + 1] = (M[j] + M[j + 1]) % MOD;
M[j] = (LL)M[j] * (MOD - x[i]) % MOD;
}
}
// 求出 xx_i = M'(x_i) = (M / (x - x_i)) mod (x - x_i) 一定非零
for (int i = n - 1; i >= 0; --i) {
for (int j = 0; j < n; ++j) {
xx[j] = ((LL)xx[j] * x[j] + (LL)M[i + 1] * (i + 1)) % MOD;
}
}
// 组合出 f(x) = sum_(i=0..n-1)(y_i / M'(x_i))(M / (x - x_i))
for (int i = 0; i < n; ++i) {
LL t = (LL)y[i] * inv(xx[i],MOD) % MOD, k = M[n];
for (int j = n - 1; j >= 0; --j) {
f[j] = (f[j] + k * t) % MOD;
k = (M[j] + k * x[i]) % MOD;
}
}
return f;
}
void solve(){
cin>>k>>n;
rep(i,1,k){
vi x,y;
map<int,int>mp;
int t=i;
rep(j,2,sqrt(t)){
while(t%j==0){
t/=j;
mp[j]++;
}
}
if(t!=1)mp[t]++;
int sum=0;
rep(j,1,18){
int cur=1;
for(auto &[key,val]:mp){
cur=cur*C(j-1+val,j-1)%M2;
}
sum+=cur;
sum%=M2;
x.push_back(j);
y.push_back(sum);
}
vi f=lagrange_interpolation(x,y,M2);
int ans=0;
rep(i,0,17){
ans+=f[i]*power(n,i,M2)%M2;
ans%=M2;
}
cout<<ans<<' ';
}
cout<<'\n';
}