1706. 高精度乘法

使用洛谷题目 P1932 MashPlant 大佬的模板:

注意:朴素做法的时间复杂度是 $O(n^2)$;如果想要优化到 $O(n\log n)$,需要使用 FFT。

朴素:

// luogu-judger-enable-o2
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>

typedef int i32;
typedef unsigned u32;
typedef unsigned long long u64;

// Signed big integer stored as little-endian base-1e9 limbs.
//
// BigInt does not own its storage: `val` points at a caller-provided buffer,
// `len` is the number of limbs in use and `sgn` is 1 for negative values.
// Every operation works in place, so the caller must make sure the buffer is
// large enough for the result (the per-method comments give the bound).
// Operands are expected to be trimmed (no zero limb at index len - 1).
struct BigInt {
  const static u32 exp = 9;           // decimal digits per limb
  const static u32 mod = 1000000000;  // 10^exp

  // Compares magnitudes only; returns -1, 0 or 1.
  static i32 abs_comp(const BigInt &lhs, const BigInt &rhs) {
    if (lhs.len != rhs.len) return lhs.len < rhs.len ? -1 : 1;
    // `~i` becomes 0 exactly when the unsigned index wraps past zero.
    for (u32 i = lhs.len - 1; ~i; --i)
      if (lhs.val[i] != rhs.val[i]) return lhs.val[i] < rhs.val[i] ? -1 : 1;
    return 0;
  }

  u32 *val, len, sgn;

  BigInt(u32 *val = nullptr, u32 len = 0, u32 sgn = 0) : val(val), len(len), sgn(sgn) {}

  // Copies value and sign into dst's buffer.
  // copy_to cannot guarantee dst.val[i] == 0 for i >= len, so every other
  // function has to clear the positions it assumes to be zero.
  void copy_to(BigInt &dst) const {
    dst.len = len, dst.sgn = sgn;
    memcpy(dst.val, val, sizeof(u32) * len);
  }

  // Drops leading zero limbs; zero is always stored with sgn == 0.
  void trim() {
    while (len && !val[len - 1]) --len;
    if (len == 0) sgn = 0;
  }

  // *this += x. Needs max(len, x.len) + 1 limbs of capacity.
  // When the signs differ, x.sgn is flipped (and left flipped) and the work
  // is delegated to sub(), which may also swap *this and x.
  void add(BigInt &x) {
    if (sgn != x.sgn) {
      x.sgn ^= 1;
      sub(x);
      return;
    }
    // Clear everything above the current length up to and including the
    // carry limb: stale data may be there after copy_to into a reused buffer.
    const u32 new_len = std::max(len, x.len);
    for (u32 i = len; i <= new_len; ++i) val[i] = 0;
    len = new_len;
    for (u32 i = 0; i < x.len; ++i)
      if ((val[i] += x.val[i]) >= mod) val[i] -= mod, ++val[i + 1];
    // Ripple the remaining carry through the limbs x does not reach.
    for (u32 i = x.len; i < len && val[i] >= mod; ++i) val[i] -= mod, ++val[i + 1];
    if (val[len]) ++len;
  }

  // *this -= x. When |*this| < |x| the two views are swapped (so x ends up
  // referring to the old *this) and the sign is flipped.
  void sub(BigInt &x) {
    if (sgn != x.sgn) {
      x.sgn ^= 1;
      add(x);
      return;
    }
    if (abs_comp(*this, x) < 0) {
      std::swap(*this, x);
      sgn ^= 1;
    }
    val[len] = 0;
    // A limb that wrapped below zero has its top bit set, since limbs < 2^30.
    for (u32 i = 0; i < x.len; ++i)
      if ((val[i] -= x.val[i]) & 0x80000000) val[i] += mod, --val[i + 1];
    for (u32 i = x.len; i < len && (val[i] & 0x80000000); ++i) val[i] += mod, --val[i + 1];
    trim();
  }

  // *this *= x, schoolbook O(len * x.len).
  // ext_mem is scratch space of at least len limbs; val needs len + x.len.
  void mul(BigInt &x, u32 *ext_mem) {
    assert(this != &x);
    memcpy(ext_mem, val, sizeof(u32) * len);
    memset(val, 0, sizeof(u32) * (len + x.len));
    for (u32 i = 0; i < len; ++i)
      for (u32 j = 0; j < x.len; ++j) {
        const u64 cur = (u64)ext_mem[i] * x.val[j] + val[i + j];
        val[i + j] = cur % mod;
        val[i + j + 1] += cur / mod;
      }
    len += x.len, sgn ^= x.sgn;
    trim();
  }

  // *this *= x, where x is read as a two's-complement signed 32-bit value.
  // May append up to two limbs.
  void mul(u32 x) {
    if (x & 0x80000000) x = -x, sgn ^= 1;
    u64 carry = 0;
    for (u32 i = 0; i < len; ++i) {
      carry += (u64)val[i] * x;
      val[i] = carry % mod;
      carry /= mod;
    }
    // For x > mod the final carry can itself exceed one limb, so split it
    // instead of storing it unnormalised.
    while (carry) {
      val[len++] = carry % mod;
      carry /= mod;
    }
    trim();
  }

  // Long division: *this becomes the quotient (truncated towards zero) and
  // `remainder` receives the remainder with the sign of the dividend.
  // remainder.val needs len + 1 limbs, ext_mem needs x.len + 1 limbs.
  void div(BigInt &x, BigInt &remainder, u32 *ext_mem) {
    assert(this != &x && this != &remainder);
    assert(x.len != 0);  // division by zero
    copy_to(remainder), memset(val, 0, sizeof(u32) * len);
    u32 shift = len - x.len;
    if (shift & 0x80000000) {  // dividend shorter than divisor: quotient is 0
      len = sgn = 0;
      return;
    }
    while (~shift) {
      u32 l = 0, r = mod;
      BigInt mul_test{ext_mem};
      // The remainder may already be shorter than `shift`; clamp so the
      // unsigned subtraction cannot wrap to a huge length.
      BigInt remainder_high{remainder.val + shift,
                            remainder.len > shift ? remainder.len - shift : 0};
      // Largest digit r with |x| * r <= remainder_high. Using `<=` keeps the
      // exact-multiple case correct and guarantees r never drops below 0.
      while (l <= r) {
        const u32 mid = (l + r) / 2;
        x.copy_to(mul_test), mul_test.mul(mid);
        if (abs_comp(mul_test, remainder_high) <= 0)
          l = mid + 1;
        else
          r = mid - 1;
      }
      val[shift] = r;
      x.copy_to(mul_test), mul_test.mul(r);
      // Subtract magnitudes: with the divisor's sign left in place, sub()
      // would turn into an addition for negative divisors.
      mul_test.sgn = 0;
      remainder_high.sub(mul_test), remainder.trim();
      --shift;
    }
    sgn ^= x.sgn;
    trim();
  }

  // *this /= x (truncated), where x is read as a signed 32-bit value.
  void div(u32 x) {
    if (x & 0x80000000) x = -x, sgn ^= 1;
    assert(x != 0);  // division by zero
    u64 carry = 0;
    for (u32 i = len - 1; ~i; --i) {
      carry = carry * mod + val[i];
      val[i] = carry / x;
      carry %= x;
    }
    trim();
  }

  // Parses an optionally '-'-prefixed decimal string into val.
  // The characters after the sign are assumed to be digits.
  void read(const char *s) {
    sgn = len = 0;
    i32 bound = 0;  // index of the first digit
    if (s[0] == '-') sgn = bound = 1;
    const i32 width = static_cast<i32>(exp);
    i32 pos = static_cast<i32>(strlen(s)) - 1;
    // Consume full 9-digit groups from the least significant end.
    while (pos + 1 >= width + bound) {
      u64 cur = 0, pow = 1;
      for (i32 i = pos; i > pos - width; --i) cur += (s[i] - '0') * pow, pow *= 10;
      val[len++] = cur;
      pos -= width;
    }
    // Leftover most significant digits (fewer than 9).
    u64 cur = 0, pow = 1;
    for (; pos >= bound; --pos) cur += (s[pos] - '0') * pow, pow *= 10;
    if (cur) val[len++] = cur;
    // Leading zeros in the text would otherwise leave zero limbs on top
    // (breaking abs_comp) or a negative zero.
    trim();
  }

  // Prints the decimal value followed by a newline.
  void print() {
    if (len) {
      if (sgn) putchar('-');
      printf("%u", val[len - 1]);
      // The '*' field width must be an int; lower limbs are zero-padded.
      for (u32 i = len - 2; ~i; --i) printf("%0*u", static_cast<int>(exp), val[i]);
    } else {
      putchar('0');
    }
    puts("");
  }
};

const int N = 1e5 + 20;
u32 a_[N], b_[N], tmp[N * 2];
char sa[N], sb[N];

// Reads two integers from stdin and prints their product.
int main() {
  // The field width is N - 1 so an over-long token cannot overflow sa / sb.
  // A missing operand used to be read as an empty string (value 0), so the
  // product is printed as 0 in that case.
  if (scanf("%100019s%100019s", sa, sb) != 2) {
    puts("0");
    return 0;
  }
  BigInt a{a_}, b{b_};
  a.read(sa), b.read(sb), a.mul(b, tmp), a.print();
  return 0;
}

FFT:

#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <utility>

typedef long long ll;
typedef double db;
typedef std::complex<db> cp;

// Largest transform size the static buffers can hold.
const ll kMaxSize = 1LL << 18;
// Longest operand in decimal digits: main() picks the smallest power of two
// that is >= 2 * digits - 1, and that must not exceed kMaxSize.
const ll kMaxDigits = kMaxSize / 2;
// One spare slot: the final carry lands in ans[size], and the input buffers
// need room for the terminating NUL.
const ll mn = kMaxSize + 1;

const db pi = std::acos(-1.0);
ll rev[mn];        // bit-reversal permutation for the current transform size
ll ans[mn];        // product digits, least significant first
cp a[mn], b[mn];   // operand digits, then their spectra
char s1[mn], s2[mn];

// In-place iterative radix-2 FFT over a[0..n).
//   n    - transform size; must be the power of two that rev[] was built for
//   flag - +1 for the forward transform, -1 for the inverse, which also
//          divides every element by n
void fft(cp *a, ll n, ll flag) {
  // Reorder into bit-reversed order; `i < rev[i]` swaps each pair only once.
  for (ll i = 0; i < n; ++i)
    if (i < rev[i]) std::swap(a[i], a[rev[i]]);

  for (ll h = 1; h < n; h <<= 1) {
    const cp wn = std::exp(cp(0, flag * pi / h));  // primitive 2h-th root of unity
    for (ll j = 0; j < n; j += h << 1) {
      cp w(1, 0);
      for (ll k = j; k < j + h; ++k) {
        const cp x = a[k], y = w * a[k + h];
        a[k] = x + y;
        a[k + h] = x - y;
        w *= wn;
      }
    }
  }

  if (flag == -1)
    for (ll i = 0; i < n; ++i) a[i] /= static_cast<db>(n);
}

// Reads two non-negative decimal integers from stdin and prints their product.
int main() {
  // Field width is mn - 1 so a token cannot overflow s1 / s2. A missing
  // operand used to be treated as an empty string, i.e. a product of 0.
  if (std::scanf("%262144s%262144s", s1, s2) != 2) {
    std::puts("0");
    return 0;
  }
  const ll n1 = static_cast<ll>(std::strlen(s1));
  const ll n2 = static_cast<ll>(std::strlen(s2));
  const ll n = std::max(n1, n2);
  if (n > kMaxDigits) {
    // A larger operand would need a transform that does not fit the buffers.
    std::fputs("operand too long for the FFT buffers\n", stderr);
    return 1;
  }

  // Store the digits least significant first.
  for (ll i = 0; i < n1; ++i) a[i] = static_cast<db>(s1[n1 - i - 1] - '0');
  for (ll i = 0; i < n2; ++i) b[i] = static_cast<db>(s2[n2 - i - 1] - '0');

  // Smallest power of two (at least 2) that is >= 2n - 1.
  ll bits = 1, size = 2;
  while (size < 2 * n - 1) {
    ++bits;
    size <<= 1;
  }
  for (ll i = 0; i < size; ++i)
    rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));

  fft(a, size, 1);
  fft(b, size, 1);
  // Pointwise product of the spectra; only indices below `size` belong to
  // the transform.
  for (ll i = 0; i < size; ++i) a[i] *= b[i];
  fft(a, size, -1);

  // Round each coefficient to the nearest integer and propagate carries.
  for (ll i = 0; i < size; ++i) {
    ans[i] += static_cast<ll>(a[i].real() + 0.5);
    ans[i + 1] += ans[i] / 10;
    ans[i] %= 10;
  }

  // Skip leading zeros. The bound is checked first so ans[-1] is never read.
  ll top = size;
  while (top > -1 && !ans[top]) --top;

  if (top == -1) {
    std::puts("0");
  } else {
    for (ll i = top; i >= 0; --i) std::printf("%lld", ans[i]);
  }
  return 0;
}