Loading [MathJax]/jax/output/CommonHTML/jax.js
My Algorithm : kopricky アルゴリズムライブラリ

kopricky アルゴリズムライブラリ

Polynomial Interpolation

コードについての説明(個人的メモ)

体上の係数が未知の $n-1$ 次多項式 $f(x) = \sum_{i=0}^{n-1} c_i x^i$ に対して $n$ 個の値の組 $(u_0, v_0 = f(u_0)), \ldots, (u_{n-1}, v_{n-1} = f(u_{n-1}))$ が分かっているとする. ただし, $u_i$ はすべて異なるものとする.
このとき $f(x)$ のすべての係数 $c_i\ (0 \le i < n)$ を求めるアルゴリズムである.
以下の実装は $\mathbb{Z}/p\mathbb{Z}$ ($p$ は素数) の場合の実装で, 特に素数 $p$ について $p = k \cdot 2^l + 1\ (k > 0,\ n + 1 \le 2^l)$ を満たすこと(数論変換に乗ること)を仮定している.
こちらの参考資料が分かりやすく, それを基に実装を行った.

(関数)
polynomial_interpolation(u,v): 未知の $n-1$ 次多項式 $f$ の $n$ 点 $u_0, \ldots, u_{n-1}$ における値 $v_0 = f(u_0), \ldots, v_{n-1} = f(u_{n-1})$ を与える($u_i$ は全て異なるものとする). 多項式 $f(x) = \sum_{i=0}^{n-1} c_i x^i$ の係数 $c$ を配列として返す.

時間計算量: $O(n \log^2 n)$

コード

  1. #define MOD 998244353
  2. #define root 3
  3.  
  4. unsigned int add(const unsigned int x, const unsigned int y)
  5. {
  6. return (x + y < MOD) ? x + y : x + y - MOD;
  7. }
  8.  
  9. unsigned int sub(const unsigned int x, const unsigned int y)
  10. {
  11. return (x >= y) ? (x - y) : (MOD - y + x);
  12. }
  13.  
  14. unsigned int mul(const unsigned int x, const unsigned int y)
  15. {
  16. return (unsigned long long)x * y % MOD;
  17. }
  18.  
  19. unsigned int mod_pow(unsigned int x, unsigned int n)
  20. {
  21. unsigned int res = 1;
  22. while(n > 0) {
  23. if(n & 1) { res = mul(res, x); }
  24. x = mul(x, x);
  25. n >>= 1;
  26. }
  27. return res;
  28. }
  29.  
  30. unsigned int inverse(const unsigned int x)
  31. {
  32. return mod_pow(x, MOD - 2);
  33. }
  34.  
  35. void ntt(vector<int>& a, const bool rev = false)
  36. {
  37. unsigned int i, j, k, l, p, q, r, s;
  38. const unsigned int size = a.size();
  39. if(size == 1) return;
  40. vector<int> b(size);
  41. r = rev ? (MOD - 1 - (MOD - 1) / size) : (MOD - 1) / size;
  42. s = mod_pow(root, r);
  43. vector<unsigned int> kp(size / 2 + 1, 1);
  44. for(i = 0; i < size / 2; ++i) kp[i + 1] = mul(kp[i], s);
  45. for(i = 1, l = size / 2; i < size; i <<= 1, l >>= 1){
  46. for(j = 0, r = 0; j < l; ++j, r += i){
  47. for(k = 0, s = kp[i * j]; k < i; ++k){
  48. p = a[k + r], q = a[k + r + size / 2];
  49. b[k + 2 * r] = add(p, q);
  50. b[k + 2 * r + i] = mul(sub(p, q), s);
  51. }
  52. }
  53. swap(a, b);
  54. }
  55. if(rev){
  56. s = inverse(size);
  57. for(i = 0; i < size; i++){ a[i] = mul(a[i], s); }
  58. }
  59. }
  60.  
  61. vector<int> convolute(const vector<int>& a, const vector<int>& b, int asize, int bsize, int _size)
  62. {
  63. if((long long)asize * min(bsize, _size) < 128LL){
  64. vector<int> A(_size, 0);
  65. for(int i = 0; i < asize; ++i){
  66. for(int j = 0; j < min(bsize, _size - i); ++j){
  67. A[i+j] = add(A[i+j], mul(a[i], b[j]));
  68. }
  69. }
  70. return A;
  71. }
  72. const int size = asize + bsize - 1;
  73. int t = 1;
  74. while(t < size){ t <<= 1; }
  75. vector<int> A(t, 0), B(t, 0);
  76. for(int i = 0; i < asize; i++){ A[i] = (a[i] < MOD) ? a[i] : (a[i] % MOD); }
  77. for(int i = 0; i < bsize; i++){ B[i] = (b[i] < MOD) ? b[i] : (b[i] % MOD); }
  78. ntt(A), ntt(B);
  79. for(int i = 0; i < t; i++) { A[i] = mul(A[i], B[i]); }
  80. ntt(A, true);
  81. A.resize(_size);
  82. return A;
  83. }
  84.  
  85. vector<int> polynomial_inverse(const vector<int>& a, int r){
  86. vector<int> h = {(int)inverse(a[0])};
  87. int t = 1;
  88. for(int i = 0; t < r; ++i){
  89. t <<= 1;
  90. vector<int> res = convolute(a, convolute(h, h, t / 2, t / 2, t), min((int)a.size(), t), t, t);
  91. for(int j = 0; j < t; ++j){
  92. res[j] = MOD - res[j];
  93. if(j < t / 2) res[j] = add(res[j], mul(2, h[j]));
  94. }
  95. swap(h, res);
  96. }
  97. h.resize(r);
  98. return h;
  99. }
  100.  
  101. pair<vector<int>, vector<int> > polynomial_division(const vector<int>& a, const vector<int>& b)
  102. {
  103. const int n = a.size() - 1, m = b.size() - 1;
  104. assert(b[m] != 0);
  105. if(n < m) return {vector<int>(n - m + 1, 0), a};
  106. vector<int> reva(n + 1), revb(m + 1);
  107. for(int i = 0; i <= n; ++i) reva[n - i] = a[i];
  108. for(int i = 0; i <= m; ++i) revb[m - i] = b[i];
  109. vector<int> inv_revb = polynomial_inverse(revb, n - m + 1);
  110. vector<int> res = convolute(reva, inv_revb, n - m + 1, n - m + 1, n - m + 1);
  111. vector<int> q(n - m + 1), r;
  112. for(int i = 0; i < n - m + 1; ++i) q[i] = res[n - m - i];
  113. vector<int> qb = convolute(q, b, n - m + 1, m + 1, n + 1);
  114. bool first = false;
  115. for(int i = n; i >= 0; --i){
  116. const int val = sub(a[i], qb[i]);
  117. if(!first && val > 0){
  118. first = true, r.resize(i + 1);
  119. }
  120. if(first) r[i] = val;
  121. }
  122. return {q, r};
  123. }
  124.  
  125. int func(const vector<int>& f, const int u){
  126. int res = 0, mult = 1;
  127. for(int i = 0; i < (int)f.size(); ++i){
  128. res = add(res, mul(f[i], mult));
  129. mult = mul(mult, u);
  130. }
  131. return res;
  132. }
  133.  
  134. int pre_computation(const vector<int>& u, vector<vector<vector<int> > >& p)
  135. {
  136. const int m = (int)u.size();
  137. int sz = 1, t = 1;
  138. while(t < m) ++sz, t <<= 1;
  139. const int res = t;
  140. p.resize(sz), p[sz - 1].resize(t);
  141. for(int j = 0; j < m; ++j){
  142. p[sz - 1][j] = {(int)sub(0, u[j]), 1};
  143. }
  144. for(int j = m; j < t; ++j){
  145. p[sz - 1][j] = {1};
  146. }
  147. t /= 2;
  148. for(int i = sz - 2; i >= 0; --i){
  149. p[i].resize(t);
  150. for(int j = 0; j < t; ++j){
  151. const int x = (int)p[i+1][2*j].size(), y = (int)p[i+1][2*j+1].size();
  152. if(y > 1) p[i][j] = convolute(p[i+1][2*j], p[i+1][2*j+1], x, y, x + y - 1);
  153. else p[i][j] = p[i+1][2*j];
  154. }
  155. t /= 2;
  156. }
  157. return res;
  158. }
  159.  
  160. void recursive_multipoint_evaluation(const vector<int>& f, const vector<int>& u,
  161. const vector<vector<vector<int> > >& p, vector<int>& ans, const int usize, const int size,
  162. const int depth, const int num)
  163. {
  164. if(usize <= 32){
  165. for(int i = 0; i < usize; ++i){
  166. const int ad = func(f, u[ans.size()]);
  167. ans.push_back(ad);
  168. }
  169. return;
  170. }
  171. const int lsize = min(usize, size / 2), rsize = max(usize - size / 2, 0);
  172. auto r0 = polynomial_division(f, p[depth + 1][2 * num]);
  173. recursive_multipoint_evaluation(r0.second, u, p, ans, lsize, size / 2, depth + 1, 2 * num);
  174. if(rsize == 0) return;
  175. auto r1 = polynomial_division(f, p[depth + 1][2 * num + 1]);
  176. recursive_multipoint_evaluation(r1.second, u, p, ans, rsize, size / 2, depth + 1, 2 * num + 1);
  177. }
  178.  
  179. vector<int> multipoint_evaluation
  180. (const vector<int>& f, const vector<int>& u, const vector<vector<vector<int> > >& p, const int al)
  181. {
  182. vector<int> ans;
  183. return recursive_multipoint_evaluation(f, u, p, ans, (int)u.size(), al, 0, 0), ans;
  184. }
  185.  
  186. vector<int> differentiate(const vector<int>& a)
  187. {
  188. const int n = (int)a.size();
  189. vector<int> res(n - 1);
  190. for(int i = 0; i < n - 1; ++i){
  191. res[i] = mul(a[i + 1], i + 1);
  192. }
  193. return res;
  194. }
  195.  
  196. vector<int> _pre_computation
  197. (const vector<int>& u, const vector<int>& v, vector<vector<vector<int> > >& p)
  198. {
  199. const int al = pre_computation(u, p);
  200. const vector<int> g = differentiate(p[0][0]);
  201. return multipoint_evaluation(g, u, p, al);
  202. }
  203.  
  204. void recursive_polynomial_interpolation
  205. (const vector<int>& u, const vector<int>& v, vector<int>& res,
  206. const vector<vector<vector<int> > >& p, const int usize, const int size,
  207. const int depth, const int num, const int id)
  208. {
  209. if(usize == 1){
  210. res = {v[id]};
  211. return;
  212. }
  213. const int lsize = min(usize, size / 2), rsize = max(usize - size / 2, 0);
  214. if(rsize == 0){
  215. vector<int> r0;
  216. recursive_polynomial_interpolation(u, v, r0, p, lsize, size / 2, depth + 1, 2 * num, id);
  217. res.resize(lsize);
  218. for(int i = 0; i < lsize; ++i) res[i] = r0[i];
  219. }else{
  220. vector<int> r0, r1;
  221. recursive_polynomial_interpolation(u, v, r0, p, lsize, size / 2, depth + 1, 2 * num, id);
  222. recursive_polynomial_interpolation(u, v, r1, p, rsize, size / 2, depth + 1, 2 * num + 1, id + size / 2);
  223. const vector<int> res1 = convolute(r0, p[depth + 1][2 * num + 1], lsize, rsize + 1, lsize + rsize);
  224. const vector<int> res2 = convolute(r1, p[depth + 1][2 * num], rsize, lsize + 1, lsize + rsize);
  225. res.resize(lsize + rsize);
  226. for(int i = 0; i < lsize + rsize; ++i) res[i] = add(res1[i], res2[i]);
  227. }
  228. }
  229.  
  230. vector<int> polynomial_interpolation(const vector<int>& u, const vector<int>& v)
  231. {
  232. const int n = (int)u.size();
  233. int t = 1;
  234. while(t < n) t <<= 1;
  235. vector<vector<vector<int> > > p;
  236. const vector<int> inv_s = _pre_computation(u, v, p);
  237. vector<int> vs(n);
  238. for(int i = 0; i < n; ++i) vs[i] = mul(v[i], inverse(inv_s[i]));
  239. vector<int> res;
  240. recursive_polynomial_interpolation(u, vs, res, p, n, t, 0, 0, 0);
  241. return res;
  242. }

verify 用の問題

yosupo さんの library checker : Polynomial Interpolation 提出コード