使用洛谷题目 P1932 中 MashPlant 大佬的模板:
注意:朴素版本的复杂度是 $O(n^2)$ 的;如果想要优化到 $O(n\log n)$,需要使用 FFT。
朴素:
// luogu-judger-enable-o2
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cassert>
typedef int i32;
typedef unsigned u32;
typedef unsigned long long u64;

// Arbitrary-precision signed integer stored as base-10^9 limbs.
//
// Storage is not owned: `val` points at a caller-provided buffer, and every
// operation assumes that buffer is large enough for its result (several
// operations also write one spare limb at index `len`). Limbs are
// little-endian (val[0] is least significant), `len` counts significant limbs
// and `sgn` is 1 for negative values. After trim(), zero is len == 0, sgn == 0.
struct BigInt
{
    const static u32 exp = 9;          // decimal digits per limb
    const static u32 mod = 1000000000; // limb base, 10^exp

    // Compares |lhs| with |rhs|; returns -1, 0 or 1.
    // Assumes both are trimmed (no leading zero limbs).
    static i32 abs_comp(const BigInt &lhs, const BigInt &rhs)
    {
        if (lhs.len != rhs.len)
            return lhs.len < rhs.len ? -1 : 1;
        for (u32 i = lhs.len - 1; ~i; --i)
            if (lhs.val[i] != rhs.val[i])
                return lhs.val[i] < rhs.val[i] ? -1 : 1;
        return 0;
    }

    u32 *val, len, sgn;
    BigInt(u32 *val = nullptr, u32 len = 0, u32 sgn = 0) : val(val), len(len), sgn(sgn) {}

    // copy_to cannot guarantee val[x] == 0 for x >= len;
    // other functions must zero any position they assume to be zero.
    void copy_to(BigInt &dst) const
    {
        dst.len = len, dst.sgn = sgn;
        memcpy(dst.val, val, sizeof(u32) * len);
    }

    // Drops leading zero limbs; a value that becomes empty is made non-negative.
    void trim()
    {
        while (len && !val[len - 1])
            --len;
        if (len == 0)
            sgn = 0;
    }

    // *this += x.
    // x is used as scratch: when the signs differ its sign is flipped and the
    // work is delegated to sub(), which may swap storage with x.
    void add(BigInt &x)
    {
        if (sgn ^ x.sgn)
            return x.sgn ^= 1, sub(x);
        // Limbs at or above len may be stale (see copy_to): zero the ones the
        // addition is about to read before extending len.
        if (len < x.len)
            memset(val + len, 0, sizeof(u32) * (x.len - len));
        val[len = std::max(len, x.len)] = 0; // carry slot
        for (u32 i = 0; i < x.len; ++i)
            if ((val[i] += x.val[i]) >= mod)
                val[i] -= mod, ++val[i + 1];
        for (u32 i = x.len; i < len && val[i] >= mod; ++i)
            val[i] -= mod, ++val[i + 1];
        if (val[len])
            ++len;
    }

    // *this -= x.
    // x is used as scratch: its sign may be flipped (delegating to add()), and
    // if |*this| < |x| the two objects are swapped wholesale (including their
    // `val` buffers) so that the larger magnitude is subtracted from.
    void sub(BigInt &x)
    {
        if (sgn ^ x.sgn)
            return x.sgn ^= 1, add(x);
        if (abs_comp(*this, x) < 0)
            std::swap(*this, x), sgn ^= 1;
        val[len] = 0;
        // A limb whose top bit is set has wrapped below zero: borrow from above.
        for (u32 i = 0; i < x.len; ++i)
            if ((val[i] -= x.val[i]) & 0x80000000)
                val[i] += mod, --val[i + 1];
        for (u32 i = x.len; i < len && val[i] & 0x80000000; ++i)
            val[i] += mod, --val[i + 1];
        trim();
    }

    // *this *= x (schoolbook, O(len * x.len)).
    // ext_mem: scratch of at least len limbs; *this's buffer must hold
    // len + x.len limbs. x must be a different object.
    void mul(BigInt &x, u32 *ext_mem)
    {
        assert(this != &x);
        memcpy(ext_mem, val, sizeof(u32) * len);
        memset(val, 0, sizeof(u32) * (len + x.len));
        for (u32 i = 0; i < len; ++i)
            for (u32 j = 0; j < x.len; ++j)
            {
                u64 tmp = (u64)ext_mem[i] * x.val[j] + val[i + j];
                val[i + j] = tmp % mod;
                val[i + j + 1] += tmp / mod;
            }
        len += x.len, sgn ^= x.sgn;
        trim();
    }

    // *this *= x, with x reinterpreted as a signed 32-bit value (top bit = sign).
    void mul(u32 x)
    {
        if (x & 0x80000000)
            x = -x, sgn ^= 1;
        u64 carry = 0;
        for (u32 i = 0; i < len; ++i)
        {
            carry += (u64)val[i] * x;
            val[i] = carry % mod;
            carry /= mod;
        }
        // x may exceed mod, so the final carry can need more than one limb.
        for (; carry; carry /= mod)
            val[len++] = carry % mod;
        trim();
    }

    // Truncating division: *this = *this / x and remainder = *this % x. The
    // remainder keeps the sign of the dividend.
    // remainder's buffer must hold len + 1 limbs; ext_mem is scratch of at
    // least x.len + 1 limbs. x must be non-zero and distinct from *this.
    void div(BigInt &x, BigInt &remainder, u32 *ext_mem)
    {
        assert(this != &x && this != &remainder);
        assert(x.len && "division by zero");
        copy_to(remainder), memset(val, 0, sizeof(u32) * len);
        u32 shift = len - x.len;
        if (shift & 0x80000000) // fewer limbs than the divisor: quotient is 0
            return void(len = sgn = 0);
        while (~shift)
        {
            // View of remainder's limbs from position `shift` upward. Earlier
            // steps may already have trimmed remainder below `shift`.
            u32 high_len = remainder.len > shift ? remainder.len - shift : 0;
            BigInt mul_test{ext_mem}, remainder_high{remainder.val + shift, high_len};
            // Binary search for the largest digit q with |x| * q <= remainder_high.
            u32 l = 0, r = mod;
            while (l <= r)
            {
                u32 mid = (l + r) / 2;
                // Compare magnitudes only: drop x's sign before scaling.
                x.copy_to(mul_test), mul_test.sgn = 0, mul_test.mul(mid);
                abs_comp(mul_test, remainder_high) <= 0 ? l = mid + 1 : r = mid - 1;
            }
            val[shift] = r;
            x.copy_to(mul_test), mul_test.sgn = 0, mul_test.mul(r);
            remainder_high.sub(mul_test), remainder.trim();
            --shift;
        }
        sgn ^= x.sgn;
        trim();
    }

    // *this /= x in place (truncating; the remainder is discarded), with x
    // reinterpreted as a signed 32-bit value. x must be non-zero.
    void div(u32 x)
    {
        assert(x && "division by zero");
        if (x & 0x80000000)
            x = -x, sgn ^= 1;
        u64 carry = 0;
        for (u32 i = len - 1; ~i; --i)
        {
            carry = carry * mod + val[i];
            val[i] = carry / x;
            carry %= x;
        }
        trim();
    }

    // Parses an optional leading '-' followed by decimal digits (no
    // validation). Leading zeros are accepted and trimmed away.
    void read(const char *s)
    {
        sgn = len = 0;
        i32 bound = 0, pos;
        if (s[0] == '-')
            sgn = bound = 1;
        u64 cur = 0, pow = 1;
        // Consume full 9-digit groups from the least significant end.
        for (pos = (i32)strlen(s) - 1; pos + 1 >= (i32)exp + bound;
             pos -= (i32)exp, val[len++] = cur, cur = 0, pow = 1)
            for (i32 i = pos; i + (i32)exp > pos; --i)
                cur += (s[i] - '0') * pow, pow *= 10;
        // Remaining (most significant) partial group.
        for (cur = 0, pow = 1; pos >= bound; --pos)
            cur += (s[pos] - '0') * pow, pow *= 10;
        if (cur)
            val[len++] = cur;
        trim(); // removes zero limbs from leading zeros and normalizes "-0"
    }

    // Writes the value in decimal to stdout, followed by a newline.
    void print() const
    {
        if (len)
        {
            if (sgn)
                putchar('-');
            printf("%u", val[len - 1]);
            // Inner limbs are zero-padded to exactly `exp` digits.
            for (u32 i = len - 2; ~i; --i)
                printf("%0*u", (int)exp, val[i]);
        }
        else
            putchar('0');
        puts("");
    }
};
const int N = 1e5 + 20;
u32 a_[N], b_[N], r_[N], tmp[N * 2];
char sa[N], sb[N];
// Reads two decimal integers and prints their product using BigInt::mul.
int main()
{
    // Limit each token to N - 1 characters so oversized input cannot overflow
    // sa/sb; the width is derived from N so the two stay in sync.
    char fmt[32];
    snprintf(fmt, sizeof fmt, "%%%ds%%%ds", N - 1, N - 1);
    if (scanf(fmt, sa, sb) != 2)
        return 1;
    {
        BigInt a{a_}, b{b_};
        a.read(sa), b.read(sb), a.mul(b, tmp), a.print();
    }
}
FFT:
#include <bits/stdc++.h>
using namespace std;
// Shared state for the FFT multiplier.
using ll = long long;
using db = double;
using cp = complex<db>;
#define sc(x) x = read()           // not used in the visible code; no read() is defined here
constexpr int mn = (1 << 18) + 1; // array capacity: 2^18 slots plus one spare

ll n1, n2;           // digit counts of the two operands
ll rev[mn];          // bit-reversal permutation for the current transform size
ll ans[mn];          // product digits, least significant first
ll k, s = 1;         // transform size s == 2^k once main has set them
ll len, n;           // n == max(n1, n2); len is not used in the visible code
db pi = acos(-1);    // pi
db v;                // not used in the visible code
cp a[mn], b[mn];     // operand coefficient arrays, transformed in place
char s1[mn], s2[mn]; // raw decimal input strings
// In-place iterative radix-2 FFT over a[0..n).
// flag == 1 runs the forward transform; flag == -1 runs the inverse, including
// the final division by n. The loop structure expects n to be a power of two,
// and the global rev[] must already hold the matching bit-reversal permutation.
void fft(cp *a, ll n, ll flag)
{
    // Permute into bit-reversed order; the idx < rev[idx] test swaps each pair once.
    for (ll idx = 0; idx < n; ++idx)
        if (idx < rev[idx])
            swap(a[idx], a[rev[idx]]);

    // Butterfly passes: combine blocks of length `half` into blocks of 2 * half.
    for (ll half = 1; half < n; half <<= 1)
    {
        // e^(i * flag * pi / half): a primitive (2 * half)-th root of unity.
        const cp step = exp(cp(0, flag * pi / half));
        for (ll base = 0; base < n; base += half << 1)
        {
            cp twiddle(1, 0);
            for (ll pos = base; pos < base + half; ++pos)
            {
                const cp even = a[pos];
                const cp odd = twiddle * a[pos + half];
                a[pos] = even + odd;
                a[pos + half] = even - odd;
                twiddle *= step;
            }
        }
    }

    if (flag != -1)
        return;
    // Inverse transform: scale by 1/n.
    for (ll idx = 0; idx < n; ++idx)
        a[idx] /= n;
}
signed main()
{
scanf("%s%s", s1, s2);
n1 = strlen(s1), n2 = strlen(s2), n = max(n1, n2);
for (ll i = 0; i < n1; ++i)
{
a[i] = (db)(s1[n1 - i - 1] - '0');
}
for (ll i = 0; i < n2; ++i)
{
b[i] = (db)(s2[n2 - i - 1] - '0');
}
k = 1, s = 2;
while ((1 << k) < (n << 1) - 1)
{
++k, s <<= 1;
}
// while (s <= n)
// {
// s <<= 1, ++k;
// }
for (ll i = 0; i < s; ++i)
{
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
}
fft(a, s, 1), fft(b, s, 1);
for (ll i = 0; i <= s; ++i)
{
a[i] *= b[i];
}
fft(a, s, -1);
for (ll i = 0; i < s; ++i)
{
ans[i] += (ll)(a[i].real() + 0.5);
ans[i + 1] += ans[i] / 10, ans[i] %= 10;
}
while (!ans[s] && s > -1)
{
--s;
}
if (s == -1)
{
puts("0");
}
else
{
for (ll i = s; i >= 0; --i)
{
printf("%lld", ans[i]);
}
}
return 0;
}