My FFT Algorithm Learning Notes
  1. FFT

FT

Fourier Transform

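DFT $\Longleftrightarrow$ IDFT: with $\omega_n = e^{2\pi i/n}$, the DFT is $X_k = \sum_{j=0}^{n-1} x_j\,\omega_n^{jk}$ and the inverse DFT recovers the input as $x_j = \frac{1}{n}\sum_{k=0}^{n-1} X_k\,\omega_n^{-jk}$.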


https://leetcode.com/problems/multiply-strings/discuss/2053278/fft-and-ntt-solutions

By Long Luo

FFT

FFT:

Implementation of FFT algorithm
FFT Optimizations

Recursion

/**
 * Recursive radix-2 FFT (Cooley–Tukey).
 *
 * Complex is a small hand-rolled complex-number type; FFT(a, invert) returns
 * the transform of `a` without modifying it.
 */
class Solution {
public:
    const double PI = std::acos(-1.0);  // PI = arccos(-1)

    // Minimal complex number with just the operations the FFT needs.
    struct Complex {
        double re, im;

        Complex(double _re = 0.0, double _im = 0.0) {
            re = _re;
            im = _im;
        }

        inline void real(const double &re) {
            this->re = re;
        }

        inline double real() const {
            return re;
        }

        inline void imag(const double &im) {
            this->im = im;
        }

        inline double imag() const {
            return im;
        }

        inline Complex operator-(const Complex &other) const {
            return Complex(re - other.re, im - other.im);
        }

        inline Complex operator+(const Complex &other) const {
            return Complex(re + other.re, im + other.im);
        }

        inline Complex operator*(const Complex &other) const {
            return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        // NOTE: unlike the operators above, this divides *this in place and
        // returns nothing; kept as-is for existing callers.
        inline void operator/(const double &div) {
            re /= div;
            im /= div;
        }

        inline void operator*=(const Complex &other) {
            *this = Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        inline void operator+=(const Complex &other) {
            this->re += other.re;
            this->im += other.im;
        }

        inline Complex conjugate() const {
            return Complex(re, -im);
        }
    };

    /**
     * Returns y[k] = sum_j a[j] * w^(j*k) with w = e^(+2*pi*i/n) when
     * invert == false, or w = e^(-2*pi*i/n) when invert == true.
     *
     * @param a      input coefficients; a.size() must be a power of two
     *               (sizes 0 and 1 are returned unchanged).
     * @param invert selects the inverse transform. The inverse is NOT scaled:
     *               divide every output by n to recover the original input.
     * @return       a new vector of the same size as `a`.
     */
    std::vector<Complex> FFT(const std::vector<Complex> &a, bool invert) {
        const int n = static_cast<int>(a.size());

        // Base case. Also covers n == 0, which previously recursed forever.
        if (n <= 1) {
            return a;
        }

        // Split into even- and odd-indexed coefficients.
        const int half = n / 2;
        std::vector<Complex> Pe(half), Po(half);
        for (int i = 0; i < half; i++) {
            Pe[i] = a[2 * i];
            Po[i] = a[2 * i + 1];
        }

        const std::vector<Complex> ye = FFT(Pe, invert);
        const std::vector<Complex> yo = FFT(Po, invert);

        // Combine with the n-th roots of unity (butterfly step).
        std::vector<Complex> y(n);
        const double ang = 2 * PI / n * (invert ? -1 : 1);
        const Complex wn(std::cos(ang), std::sin(ang));
        Complex w(1, 0);

        for (int i = 0; i < half; i++) {
            const Complex t = w * yo[i];  // twiddle product, computed once
            y[i] = ye[i] + t;
            y[i + half] = ye[i] - t;
            w = w * wn;
        }

        return y;
    }
};

Iteration

/**
 * Iterative in-place radix-2 FFT driven by precomputed tables.
 *
 * Usage: call init(n) once with the transform size n (a power of two), then
 * FFT(a, omega) for the forward transform and FFT(a, invert) for the inverse.
 * The inverse is NOT scaled; divide each output by n afterwards.
 */
class Solution {
public:
    const double PI = std::acos(-1.0);    // PI = arccos(-1)

    // Minimal complex number with just the operations the FFT needs.
    struct Complex {
        double re, im;

        Complex(double _re = 0.0, double _im = 0.0) {
            re = _re;
            im = _im;
        }

        inline void real(const double &re) {
            this->re = re;
        }

        inline double real() const {
            return re;
        }

        inline void imag(const double &im) {
            this->im = im;
        }

        inline double imag() const {
            return im;
        }

        inline Complex operator-(const Complex &other) const {
            return Complex(re - other.re, im - other.im);
        }

        inline Complex operator+(const Complex &other) const {
            return Complex(re + other.re, im + other.im);
        }

        inline Complex operator*(const Complex &other) const {
            return Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        // NOTE: unlike the operators above, this divides *this in place and
        // returns nothing; kept as-is for existing callers.
        inline void operator/(const double &div) {
            re /= div;
            im /= div;
        }

        inline void operator+=(const Complex &other) {
            this->re += other.re;
            this->im += other.im;
        }

        inline void operator-=(const Complex &other) {
            this->re -= other.re;
            this->im -= other.im;
        }

        inline void operator*=(const Complex &other) {
            *this = Complex(re * other.re - im * other.im, re * other.im + im * other.re);
        }

        inline Complex conjugate() const {
            return Complex(re, -im);
        }
    };

    // Former fixed capacity of the lookup tables. init() now sizes the tables
    // to exactly n, so this is no longer an upper bound on n. It is kept only
    // so existing references to Solution::N still compile.
    static const int N = 256;

    std::vector<Complex> omega;   // omega[k] = e^(+2*pi*i*k/n), filled by init()
    std::vector<Complex> invert;  // invert[k] = conj(omega[k]), for the inverse FFT

    std::vector<int> rev;         // bit-reversal permutation of [0, n), filled by init()

    /**
     * Prepares the root tables and the bit-reversal permutation for size n.
     * n must be a power of two. The tables are resized to n, so sizes larger
     * than the old 256-entry arrays no longer overflow.
     */
    void init(int n) {
        const int size = n > 0 ? n : 0;
        omega.assign(size, Complex());
        invert.assign(size, Complex());
        rev.assign(size, 0);

        for (int i = 0; i < size; i++) {
            const double ang = 2 * PI * i / size;
            omega[i] = Complex(std::cos(ang), std::sin(ang));
            invert[i] = omega[i].conjugate();

            if (i > 0) {
                // Reverse i by reusing the reversal of i/2 and adding the low bit on top.
                rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (size >> 1) : 0);
            }
        }
    }

    /**
     * In-place FFT of `a` using the root table `roots` (omega or invert).
     * Precondition: init(a.size()) has been called, so `rev` and `roots` both
     * hold at least a.size() entries.
     */
    void FFT(std::vector<Complex> &a, const Complex *roots) {
        const int n = static_cast<int>(a.size());

        if (n <= 1) {
            return;
        }

        // Reorder into bit-reversed index order.
        for (int i = 0; i < n; ++i) {
            if (i < rev[i]) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        for (int len = 2; len <= n; len *= 2) {
            const int half = len / 2;
            // Stride into the size-n root table. j * (n / len) equals the
            // original j * n / len because len divides n, and it avoids int
            // overflow of j * n for large n.
            const int step = n / len;
            for (int i = 0; i < n; i += len) {
                for (int j = 0; j < half; j++) {
                    const Complex u = a[i + j];
                    const Complex v = roots[j * step] * a[i + j + half];
                    a[i + j] = u + v;
                    a[i + j + half] = u - v;
                }
            }
        }
    }

    // Convenience overload so callers can pass the omega / invert tables directly.
    void FFT(std::vector<Complex> &a, const std::vector<Complex> &roots) {
        FFT(a, roots.data());
    }
};

Analysis

  • Time Complexity: $O((m+n)\log(m+n))$, where $m$ and $n$ are the lengths of the two input numbers.
  • Space Complexity: $O(m+n)$.

Number Theoretic Transform

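NTT: Number Theoretic Transform (快速数论变换, "fast number-theoretic transform"), the analogue of the FFT that works modulo a prime and so has no floating-point rounding error.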

Recursion

/**
 * Recursive radix-2 Number Theoretic Transform modulo 998244353.
 */
class Solution {

public:
    const long long G = 3;                // generator used for the roots of unity
    const long long G_INV = 332748118;    // G^(-1) mod MOD (3 * 332748118 = MOD + 1)
    const long long MOD = 998244353;      // 119 * 2^23 + 1

    std::vector<int> rev;                 // not used by the recursive ntt below

    /**
     * Returns a^b mod MOD in [0, MOD) by binary exponentiation (b >= 0).
     * The base is reduced first, so negative or oversized bases also work.
     */
    long long quickPower(long long a, long long b) {
        long long res = 1;

        a %= MOD;
        if (a < 0) {
            a += MOD;
        }

        while (b > 0) {
            if (b & 1) {
                res = (res * a) % MOD;
            }

            a = (a * a) % MOD;
            b >>= 1;
        }

        return res;
    }

    /**
     * In-place NTT of `a`. a.size() must be a power of two dividing MOD - 1
     * (i.e. at most 2^23); sizes 0 and 1 are left unchanged.
     *
     * invert == false uses the root G^((MOD-1)/n); invert == true uses its
     * inverse. The inverse is NOT scaled: multiply each output by n^(-1) mod MOD
     * to recover the original input.
     */
    void ntt(std::vector<long long> &a, bool invert) {
        const int n = static_cast<int>(a.size());

        // Base case. Also covers n == 0, which previously recursed forever.
        if (n <= 1) {
            return;
        }

        // Split into even- and odd-indexed coefficients.
        const int half = n / 2;
        std::vector<long long> Pe(half), Po(half);
        for (int i = 0; i < half; i++) {
            Pe[i] = a[2 * i];
            Po[i] = a[2 * i + 1];
        }

        ntt(Pe, invert);
        ntt(Po, invert);

        const long long wn = quickPower(invert ? G_INV : G, (MOD - 1) / n);
        long long w = 1;

        for (int i = 0; i < half; i++) {
            const long long t = w * Po[i] % MOD;
            // The extra "+ MOD) % MOD" keeps results in [0, MOD) even if t is negative.
            a[i] = ((Pe[i] + t) % MOD + MOD) % MOD;
            a[i + half] = ((Pe[i] - t) % MOD + MOD) % MOD;
            w = w * wn % MOD;
        }
    }
};

Iteration

/**
 * Iterative in-place Number Theoretic Transform modulo 998244353.
 * The inverse transform includes the 1/n normalization.
 */
class Solution {
    static const long long MOD = 998244353;   // 119 * 2^23 + 1
    static const long long G = 3;             // generator used for the roots of unity
    static const int G_INV = 332748118;       // G^(-1) mod MOD (3 * 332748118 = MOD + 1)
    std::vector<int> rev;                     // cached bit-reversal permutation for the last size used

    // Fills `rev` with the bit-reversal permutation of [0, n); n must be a power of two.
    void buildRev(int n) {
        rev.assign(n, 0);
        for (int i = 1; i < n; i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0);
        }
    }

public:
    /**
     * Returns a^b mod MOD in [0, MOD) by binary exponentiation (b >= 0).
     * The base is reduced first, so negative or oversized bases also work.
     */
    long long quickPower(long long a, long long b) {
        long long res = 1;

        a %= MOD;
        if (a < 0) {
            a += MOD;
        }

        while (b > 0) {
            if (b & 1) {
                res = (res * a) % MOD;
            }

            a = (a * a) % MOD;
            b >>= 1;
        }

        return res;
    }

    /**
     * In-place NTT of `a`. a.size() must be a power of two dividing MOD - 1
     * (i.e. at most 2^23).
     *
     * Inputs are first reduced into [0, MOD). With invert == true the result is
     * multiplied by n^(-1), so ntt(a) followed by ntt(a, true) restores `a`.
     */
    void ntt(std::vector<long long> &a, bool invert = false) {
        const int n = static_cast<int>(a.size());

        // `rev` was never filled by any member before, so indexing it was out
        // of bounds; build it here (and reuse it while n stays the same).
        if (static_cast<int>(rev.size()) != n) {
            buildRev(n);
        }

        // Reduce inputs so negative or oversized coefficients stay correct.
        for (auto &x : a) {
            x %= MOD;
            if (x < 0) {
                x += MOD;
            }
        }

        // Reorder into bit-reversed index order.
        for (int i = 0; i < n; i++) {
            if (i < rev[i]) {
                std::swap(a[i], a[rev[i]]);
            }
        }

        for (int len = 2; len <= n; len <<= 1) {
            // Principal len-th root of unity (or its inverse) modulo MOD.
            const long long wlen = quickPower(invert ? G_INV : G, (MOD - 1) / len);
            const int half = len >> 1;

            for (int i = 0; i < n; i += len) {
                long long w = 1;
                for (int j = 0; j < half; j++) {
                    const long long u = a[i + j];
                    const long long v = w * a[i + j + half] % MOD;
                    a[i + j] = (u + v) % MOD;
                    a[i + j + half] = (u - v + MOD) % MOD;
                    w = w * wlen % MOD;
                }
            }
        }

        if (invert) {
            // n^(-1) via Fermat's little theorem (MOD is prime).
            const long long nInv = quickPower(n, MOD - 2);
            for (auto &x : a) {
                x = x * nInv % MOD;
            }
        }
    }
};

Analysis

  • Time Complexity: $O((m+n)\log(m+n))$, where $m$ and $n$ are the lengths of the two input numbers.
  • Space Complexity: $O(m+n)$.

Fourier

Fourier Series

image
image
image

Comments (1)