我正在编写一个函数来使用 Stockham FFT 算法计算快速傅里叶变换,并发现如果 FFT 的长度是 2 的幂,则可以在编译时预先计算所有用于计算的“旋转因子”。
在 FFT 计算中,旋转因子计算通常占用总时间的很大一部分,因此理论上这样做应该会大大提高算法的性能。
昨天我花了一整天时间,在一个新的编译器(gcc 10)上重新实现我的算法,以便使用 C++20 的 `consteval` 功能在编译时预先计算所有旋转因子。我成功做到了,但最终在编译时预先计算所有旋转因子的代码实际上运行得更慢!
这是在运行时执行所有计算的代码:
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>
using namespace std;
// Forward declaration: FFT() below calls StockhamFFT before its definition.
static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x);
// Returns true iff `value` is a positive power of two (1, 2, 4, 8, ...).
// A power of two has exactly one bit set, so clearing its lowest set bit
// (value & (value - 1)) leaves zero. Zero has no bits set and is rejected first.
constexpr bool IsPowerOf2(const size_t value)
{
    if (value == 0)
    {
        return false;
    }
    return (value & (value - 1)) == 0;
}
vector<complex<double>> FFT(const vector<double>& x)
{
const auto N = x.size();
assert(IsPowerOf2(x.size()));
const auto NOver2 = N/2;
vector<complex<double>> x_p(N);
transform(x.begin(), x.end(), x_p.begin(), [](const double value) {
return complex<double>(value);
});
return StockhamFFT(x_p);
}
// C++ implementation of the Stockam FFT algorithm
static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x)
{
const auto N = x.size();
assert(IsPowerOf2(N));
const auto NOver2 = N/2;
// Pre-calculate the twiddle factors (at runtime)
vector<complex<double>> W(NOver2);
const auto omega = 2.0 * M_PI / N;
for (size_t n = 0; n < NOver2; ++n)
{
W[n] = complex{ cos(-omega * n), sin(-omega * n) };
}
// The Stockham algorithm requires one vector for input/output data and
// another as a temporary workspace
vector<complex<double>> a(x);
vector<complex<double>> b(N);
// Set the spacing between twiddle factors used at the first stage
size_t WStride = N/2;
// Loop through each stage of the FFT
for (size_t stride = 1; stride < N; stride *= 2)
{
// Loop through the individual FFTs of each stage
for (size_t m = 0; m < NOver2; m += stride)
{
const auto mTimes2 = m*2;
// Perform each individual FFT
for (size_t n = 0; n < stride; ++n)
{
// Calculate the input indexes
const auto aIndex1 = n + m;
const auto aIndex2 = aIndex1 + NOver2;
// Calculate the output indexes
const auto bIndex1 = n + mTimes2;
const auto bIndex2 = bIndex1 + stride;
// Perform the FFT
const auto tmp1 = a[aIndex1];
const auto tmp2 = W[n*WStride]*a[aIndex2];
// Sum the results
b[bIndex1] = tmp1 + tmp2;
b[bIndex2] = tmp1 - tmp2; // (>*.*)> symmetry! <(*.*<)
}
}
// Spacing between twiddle factors is half for the next stage
WStride /= 2;
// Swap the data (output of this stage is input of the next)
a.swap(b);
}
return a;
}
// Benchmark driver. It builds a 100 Hz sine sampled at 1 kHz (2^18 samples),
// runs FFT() kIterations times and prints the mean wall-clock time per call.
// The twiddle setup in StockhamFFT is O(N), while the butterfly stages are
// O(N log N), so moving the setup out of run time can only save a small
// fraction. Compare optimized builds (e.g. -O2), not the default -O0.
int main()
{
    // Exact integer size; avoids routing an integer through floating-point pow().
    constexpr std::size_t N = std::size_t{1} << 18;
    std::vector<double> x(N);
    const int f_s = 1000;
    const double t_s = 1.0 / f_s;
    for (std::size_t n = 0; n < N; ++n)
    {
        x[n] = std::sin(2 * M_PI * 100 * n * t_s);
    }

    // The loop count and the divisor must agree (previously 99 runs were divided by 100).
    constexpr int kIterations = 100;
    std::chrono::microseconds total{0};
    for (int i = 0; i < kIterations; ++i)
    {
        // steady_clock is monotonic, which is what interval timing requires.
        const auto start = std::chrono::steady_clock::now();
        [[maybe_unused]] const auto X = FFT(x);
        const auto stop = std::chrono::steady_clock::now();
        total += std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
    }
    const auto average = total.count() / kIterations;
    std::cout << "duration " << average << " microseconds." << '\n';
}
以此为起点,我把旋转因子的计算从 `StockhamFFT` 函数中提取出来,改由一个 `consteval` 函数在编译时完成。修改后的代码如下:
#include <algorithm>
#include <array>
#include <cassert>
#include <chrono>
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>
using namespace std;
// Forward declaration: FFT() below calls StockhamFFT before its definition.
static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x);
// True when `value` is 1, 2, 4, 8, ... (exactly one bit set). Subtracting one
// from a power of two flips its single set bit and all bits below it, so the
// AND with the original is zero. Zero is excluded by the non-zero check.
constexpr bool IsPowerOf2(const size_t value)
{
    const bool nonZero = value != 0;
    const bool singleBit = (value & (value - 1)) == 0;
    return nonZero && singleBit;
}
// Builds the radix-2 twiddle table for an N-point FFT:
//   result[k] = exp(-2*pi*i*k/N) = cos(2*pi*k/N) - i*sin(2*pi*k/N),  k in [0, N/2).
// The function is consteval, so every call is evaluated by the compiler.
// Calling std::cos/std::sin in a constant expression relies on a GCC extension.
template <std::size_t N>
static consteval std::array<std::complex<double>, N / 2> CalculateTwiddleFactors()
{
    static_assert(IsPowerOf2(N), "N must be a power of 2.");
    const double angleStep = 2.0 * M_PI / N;
    std::array<std::complex<double>, N / 2> twiddles{};
    std::size_t k = 0;
    for (auto& factor : twiddles)
    {
        const double angle = -angleStep * k;
        factor = std::complex<double>(std::cos(angle), std::sin(angle));
        ++k;
    }
    return twiddles;
}
// Twiddle-factor table for transforms of length 2^18, computed entirely at
// compile time by the consteval CalculateTwiddleFactors (2^17 entries).
// The length is written as an integer shift because std::pow is not a
// constant expression in standard C++; it only worked here as a GCC extension.
constexpr auto W = CalculateTwiddleFactors<std::size_t{1} << 18>();
vector<complex<double>> FFT(const vector<double>& x)
{
const auto N = x.size();
assert(IsPowerOf2(x.size()));
const auto NOver2 = N/2;
vector<complex<double>> x_p(N);
transform(x.begin(), x.end(), x_p.begin(), [](const double value) {
return complex<double>(value);
});
return StockhamFFT(x_p);
}
// C++ implementation of the Stockam FFT algorithm
static vector<complex<double>> StockhamFFT(const vector<complex<double>>& x)
{
const auto N = x.size();
assert(IsPowerOf2(N));
const auto NOver2 = N/2;
//***********************************************************************
// Twiddle factors are already calculated at compile time!!!
//***********************************************************************
// The Stockham algorithm requires one vector for input/output data and
// another as a temporary workspace
vector<complex<double>> a(x);
vector<complex<double>> b(N);
// Set the spacing between twiddle factors used at the first stage
size_t WStride = N/2;
// Loop through each stage of the FFT
for (size_t stride = 1; stride < N; stride *= 2)
{
// Loop through the individual FFTs of each stage
for (size_t m = 0; m < NOver2; m += stride)
{
const auto mTimes2 = m*2;
// Perform each individual FFT
for (size_t n = 0; n < stride; ++n)
{
// Calculate the input indexes
const auto aIndex1 = n + m;
const auto aIndex2 = aIndex1 + NOver2;
// Calculate the output indexes
const auto bIndex1 = n + mTimes2;
const auto bIndex2 = bIndex1 + stride;
// Perform the FFT
const auto tmp1 = a[aIndex1];
const auto tmp2 = W[n*WStride]*a[aIndex2];
// Sum the results
b[bIndex1] = tmp1 + tmp2;
b[bIndex2] = tmp1 - tmp2; // (>*.*)> symmetry! <(*.*<)
}
}
// Spacing between twiddle factors is half for the next stage
WStride /= 2;
// Swap the data (output of this stage is input of the next)
a.swap(b);
}
return a;
}
int main()
{
size_t N = pow(2, 18);
vector<double> x(N);
int f_s = 1000;
double t_s = 1.0 / f_s;
for (size_t n = 0; n < N; ++n)
{
x[n] = sin(2 * M_PI * 100 * n * t_s);
}
auto sum = 0;
for (int i = 1; i < 100; ++i)
{
auto start = chrono::high_resolution_clock::now();
auto X = FFT(x);
auto stop = chrono::high_resolution_clock::now();
auto duration = chrono::duration_cast<chrono::microseconds>(stop - start);
sum += duration.count();
}
auto average = sum / 100;
std::cout << "duration " << average << " microseconds." << std::endl;
}
这两个版本都是在 Ubuntu 19.10 上使用 gcc 10.0.1 编译的:
g++ -std=c++2a -o main main.cpp
请注意,这里必须使用 gcc,因为它是唯一一个提供 `constexpr` 版本 `sin` 和 `cos` 的编译器(标准 C++20 并未将它们声明为 `constexpr`,gcc 是以扩展的形式支持的)。
“运行时”示例产生以下结果:
`duration 292854 microseconds.`
“编译时”示例产生以下内容:
`duration 295230 microseconds.`
编译时版本的编译时间确实更长,但不知为何运行时间也更长——尽管大部分计算在程序开始运行之前就已经完成了!这怎么可能?我是否遗漏了某些关键的东西?