

DFT <==> IDFT (the discrete Fourier transform and its inverse)
https://leetcode.com/problems/multiply-strings/discuss/2053278/fft-and-ntt-solutions
By Long Luo
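FFT:
A recursive implementation of the FFT algorithm.
FFT optimizations: an iterative, in-place version with precomputed roots of unity and a bit-reversal table.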
class Solution {
public:
const double PI = acos(-1.0); // PI = arccos(-1)
/**
 * Minimal complex number used by the FFT.
 *
 * Accessors and conjugate() are const so they can be called on const objects
 * and through const references. Compound assignments return *this so they
 * can be chained.
 */
struct Complex {
    double re, im;
    Complex(double _re = 0.0, double _im = 0.0) : re(_re), im(_im) {}
    inline void real(const double &value) {
        re = value;
    }
    inline double real() const {
        return re;
    }
    inline void imag(const double &value) {
        im = value;
    }
    inline double imag() const {
        return im;
    }
    inline Complex operator-(const Complex &other) const {
        return Complex(re - other.re, im - other.im);
    }
    inline Complex operator+(const Complex &other) const {
        return Complex(re + other.re, im + other.im);
    }
    inline Complex operator*(const Complex &other) const {
        return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
    }
    // Legacy in-place division: the statement `z / d;` scales z itself and
    // yields no value. It stays void so that `Complex q = z / d;` keeps failing
    // to compile instead of silently mutating z. Prefer `z /= d` in new code.
    inline void operator/(const double &div) {
        re /= div;
        im /= div;
    }
    inline Complex &operator/=(const double &div) {
        re /= div;
        im /= div;
        return *this;
    }
    inline Complex &operator*=(const Complex &other) {
        // Go through a temporary: both new components depend on the old re and im.
        *this = *this * other;
        return *this;
    }
    inline Complex &operator+=(const Complex &other) {
        re += other.re;
        im += other.im;
        return *this;
    }
    inline Complex &operator-=(const Complex &other) {
        re -= other.re;
        im -= other.im;
        return *this;
    }
    inline Complex conjugate() const {
        return Complex(re, -im);
    }
};
vector<Complex> FFT(vector<Complex> &a, bool invert) {
int n = a.size();
if (n == 1) {
return a;
}
vector<Complex> Pe(n / 2), Po(n / 2);
for (int i = 0; 2 * i < n; i++) {
Pe[i] = a[2 * i];
Po[i] = a[2 * i + 1];
}
vector<Complex> ye = FFT(Pe, invert);
vector<Complex> yo = FFT(Po, invert);
// Combine
vector<Complex> y(n);
// Root of Units
double ang = 2 * PI / n * (invert ? -1 : 1);
Complex wn(cos(ang), sin(ang));
Complex w(1, 0);
for (int i = 0; i < n / 2; i++) {
y[i] = ye[i] + w * yo[i];
y[i + n / 2] = ye[i] - w * yo[i];
w = w * wn;
}
return y;
}
}class Solution {
public:
const double PI = acos(-1.0); // PI = arccos(-1)
/**
 * Minimal complex number used by the FFT.
 *
 * Accessors and conjugate() are const so they can be called on const objects
 * and through const references. Compound assignments return *this so they
 * can be chained.
 */
struct Complex {
    double re, im;
    Complex(double _re = 0.0, double _im = 0.0) : re(_re), im(_im) {}
    inline void real(const double &value) {
        re = value;
    }
    inline double real() const {
        return re;
    }
    inline void imag(const double &value) {
        im = value;
    }
    inline double imag() const {
        return im;
    }
    inline Complex operator-(const Complex &other) const {
        return Complex(re - other.re, im - other.im);
    }
    inline Complex operator+(const Complex &other) const {
        return Complex(re + other.re, im + other.im);
    }
    inline Complex operator*(const Complex &other) const {
        return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
    }
    // Legacy in-place division: the statement `z / d;` scales z itself and
    // yields no value. It stays void so that `Complex q = z / d;` keeps failing
    // to compile instead of silently mutating z. Prefer `z /= d` in new code.
    inline void operator/(const double &div) {
        re /= div;
        im /= div;
    }
    inline Complex &operator/=(const double &div) {
        re /= div;
        im /= div;
        return *this;
    }
    inline Complex &operator+=(const Complex &other) {
        re += other.re;
        im += other.im;
        return *this;
    }
    inline Complex &operator-=(const Complex &other) {
        re -= other.re;
        im -= other.im;
        return *this;
    }
    inline Complex &operator*=(const Complex &other) {
        // Go through a temporary: both new components depend on the old re and im.
        *this = *this * other;
        return *this;
    }
    inline Complex conjugate() const {
        return Complex(re, -im);
    }
};
// Capacity of the precomputed tables below. init(n) writes n entries and FFT()
// reads up to n entries, so every transform size must satisfy n <= N.
// NOTE(review): multiply-strings operands may have up to 200 digits each (the
// constraint is not shown in this file). With one digit per coefficient, that
// needs a transform of size 512, which does not fit in 256 entries. Hence 1024.
static const int N = 1024;
Complex omega[N];   // omega[k] = e^{+2*pi*i*k/n}, filled by init(n)
Complex invert[N];  // invert[k] = conj(omega[k]) = e^{-2*pi*i*k/n}
int rev[N];         // rev[i] = i with its log2(n) low bits reversed
/**
 * Precompute the twiddle tables and the bit-reversal permutation for size n.
 *
 * @param n transform size; must be a power of two with 1 <= n <= N. There is
 *          no bounds check, so a larger n writes past the end of the tables.
 */
void init(int n) {
    rev[0] = 0;
    for (int i = 0; i < n; i++) {
        double ang = 2 * PI * i / n;
        omega[i] = Complex(cos(ang), sin(ang));
        invert[i] = omega[i].conjugate();
        if (i > 0) {
            // Reverse i's bits: reuse rev[i / 2] shifted down, and move i's low
            // bit into the top position (n >> 1).
            rev[i] = rev[i >> 1] >> 1;
            if (i & 1) {
                rev[i] |= n >> 1;
            }
        }
    }
}
void FFT(vector<Complex> &a, Complex *omega) {
int n = a.size();
if (n == 1) {
return;
}
for (int i = 0; i < n; ++i) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
for (int len = 2; len <= n; len *= 2) {
for (int i = 0; i < n; i += len) {
for (int j = 0; j < len / 2; j++) {
Complex u = a[i + j];
Complex v = omega[j * n / len] * a[i + j + len / 2];
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
}
}
}
}
}
Time complexity: O((m+n) log(m+n)). Space complexity: O(m+n).
NTT: Number Theoretic Transform (快速数论变换, "fast number-theoretic transform").
class Solution {
public:
// Root base: ntt() raises it to (MOD - 1) / n to obtain an n-th root of unity.
const long long G = 3;
// Inverse of G modulo MOD (3 * 332748118 = MOD + 1), used for the inverse transform.
const long long G_INV = 332748118;
// Prime modulus 998244353 = 119 * 2^23 + 1.
const long long MOD = 998244353;
// NOTE(review): not referenced by any member visible in this class; kept for compatibility.
vector<int> rev;
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; reduced modulo MOD first, so any value (including negatives
 *          and values far above MOD) is accepted.
 * @param b exponent; values <= 0 yield 1.
 * @return a^b mod MOD, in [0, MOD).
 */
long long quickPower(long long a, long long b) {
    // Reduce the base up front so that a * a cannot overflow for large |a|,
    // and so that a negative base still yields a result in [0, MOD).
    a %= MOD;
    if (a < 0) {
        a += MOD;
    }
    long long res = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) {
            res = res * a % MOD;
        }
        a = a * a % MOD;
    }
    return res;
}
/**
 * Recursive radix-2 number-theoretic transform over Z/MOD.
 *
 * @param a      values, transformed in place. The size must be a power of two
 *               dividing MOD - 1; sizes 0 and 1 are left unchanged.
 * @param invert false: use roots G^((MOD-1)/n); true: use G_INV^((MOD-1)/n).
 *               The inverse result is NOT multiplied by n^{-1}; the caller
 *               must scale it.
 * For n >= 2 the outputs are normalized into [0, MOD).
 */
void ntt(vector<long long> &a, bool invert) {
    const int n = a.size();
    // n == 0 must stop here too. Otherwise an empty vector splits into two
    // empty halves and the recursion never terminates.
    if (n <= 1) {
        return;
    }
    const int half = n / 2;
    // Split into even- and odd-indexed elements.
    vector<long long> Pe(half), Po(half);
    for (int i = 0; i < half; i++) {
        Pe[i] = a[2 * i];
        Po[i] = a[2 * i + 1];
    }
    ntt(Pe, invert);
    ntt(Po, invert);
    const long long wn = quickPower(invert ? G_INV : G, (MOD - 1) / n);
    long long w = 1;
    for (int i = 0; i < half; i++) {
        const long long t = w * Po[i] % MOD;
        // The double reduction keeps results in [0, MOD) even when Pe[i] - t < 0.
        a[i] = ((Pe[i] + t) % MOD + MOD) % MOD;
        a[i + half] = ((Pe[i] - t) % MOD + MOD) % MOD;
        w = w * wn % MOD;
    }
}
}class Solution {
// Prime modulus 998244353 = 119 * 2^23 + 1, so power-of-two sizes up to 2^23 divide MOD - 1.
static const long long MOD = 998244353;
// Root base: ntt() raises it to (MOD - 1) / len to obtain a len-th root of unity.
static const long long G = 3;
// Inverse of G modulo MOD (3 * 332748118 = MOD + 1), used for the inverse transform.
// NOTE(review): declared int while MOD and G are long long. The value fits in an
// int, so this is harmless. Be careful if retyping it: `invert ? G_INV : G` with
// identically typed static consts is an lvalue, which some GCC versions odr-use
// at -O0, and that needs out-of-class definitions to link.
static const int G_INV = 332748118;
// Bit-reversal permutation used by ntt(); it is indexed by every i < a.size().
vector<int> rev;
public:
/**
 * Modular exponentiation by repeated squaring.
 *
 * @param a base; reduced modulo MOD first, so any value (including negatives
 *          and values far above MOD) is accepted.
 * @param b exponent; values <= 0 yield 1.
 * @return a^b mod MOD, in [0, MOD).
 */
long long quickPower(long long a, long long b) {
    // Reduce the base up front so that a * a cannot overflow for large |a|,
    // and so that a negative base still yields a result in [0, MOD).
    a %= MOD;
    if (a < 0) {
        a += MOD;
    }
    long long res = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) {
            res = res * a % MOD;
        }
        a = a * a % MOD;
    }
    return res;
}
/**
 * In-place iterative number-theoretic transform over Z/MOD.
 *
 * @param a      values, assumed already reduced into [0, MOD). The size must be
 *               a power of two dividing MOD - 1. a is overwritten with the transform.
 * @param invert false: forward transform with roots G^((MOD-1)/len);
 *               true: inverse transform with roots G_INV^((MOD-1)/len), followed
 *               by multiplication of every entry by n^{-1} mod MOD.
 *
 * Side effect: if the member table `rev` does not have exactly a.size()
 * entries, it is rebuilt as the bit-reversal permutation for that size.
 * A mismatched table would index past the end of `rev` or of `a`.
 */
void ntt(vector<long long> &a, bool invert = false) {
    const int n = a.size();
    if (static_cast<int>(rev.size()) != n) {
        rev.assign(n, 0);
        for (int i = 1; i < n; i++) {
            // Reverse i's bits: reuse rev[i / 2] shifted down, and move i's low
            // bit into the top position (n >> 1).
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
        }
    }
    // Reorder into bit-reversed index order; `i < rev[i]` swaps each pair once.
    for (int i = 0; i < n; i++) {
        if (i < rev[i]) {
            swap(a[i], a[rev[i]]);
        }
    }
    for (int len = 2; len <= n; len <<= 1) {
        const long long wlen = quickPower(invert ? G_INV : G, (MOD - 1) / len);
        const int half = len / 2;
        for (int i = 0; i < n; i += len) {
            long long w = 1;
            for (int j = 0; j < half; j++) {
                const long long u = a[i + j];
                const long long v = w * a[i + j + half] % MOD;
                a[i + j] = (u + v) % MOD;
                a[i + j + half] = (MOD + u - v) % MOD;
                w = w * wlen % MOD;
            }
        }
    }
    if (invert) {
        // n^{-1} mod MOD via Fermat's little theorem (MOD is prime).
        const long long inver = quickPower(n, MOD - 2);
        for (long long &x : a) {
            x = x * inver % MOD;
        }
    }
}
}
Time complexity: O((m+n) log(m+n)). Space complexity: O(m+n).



