0

我正在尝试从 Python 中的 Matlab 移植以下 ICA 实现。据我了解,它使用带有双曲余弦作为对比函数的紧缩正交化。

通过 sklearn使用FastICA 可以获得令人满意的结果,但与 Matlab 相比执行时间非常慢。

作为比较,以下示例数据的执行时间如下:

  • Python(deflation算法):4.97 秒
  • Python(parallel算法):0.04 秒
  • Matlab:0.04 秒

奇怪的deflation是,Python 中的 FastICA 算法比 Matlab 实现或 Python 中的 FastICA 算法慢 100 倍以上paralell

为什么会有这种巨大的差异,尤其是 Matlab 和 Python 版本之间的差异?我不是 ICA 专家,因此可能缺少我的配置。

这是用于生成示例数据和分析执行时间的 Python 代码:

import timeit
import numpy as np
from sklearn.decomposition import FastICA
from scipy.misc import electrocardiogram

# prepare signal
ecg = electrocardiogram().reshape(-1, 1)
np.random.seed(0)
ecg_noisy = ecg + np.random.randn(ecg.shape[0], 1)
x = np.hstack((ecg, ecg_noisy))

n = 10  # number of runs for profiling

# profile parallel FastICA algoritm
transformerParallel = FastICA(n_components=x.shape[1],
                              algorithm='parallel',  
                              random_state=0,
                              whiten=True,
                              max_iter=200,
                              fun='logcosh',
                              tol=1E-4)                            
tp = timeit.timeit(lambda: transformerParallel.fit_transform(x), number=n)
print('ICA Parallel takes {:.3f} seconds'.format(tp/n));

# profile deflational FastICA algorithm
transformerDeflation = FastICA(n_components=x.shape[1],
                               algorithm='deflation',  
                               random_state=0,
                               whiten=True,
                               max_iter=200,
                               fun='logcosh',
                               tol=1E-4)                               
td = timeit.timeit(lambda: transformerDeflation.fit_transform(x), number=n)
print('ICA Deflation takes {:.3f} seconds'.format(td/n));

# export data for profiling in Matlab
import pandas as pd
pd.DataFrame(x).to_csv('input.csv', header=False, index=False)

这是用于在 Matlab 中分析的代码(使用coshFpDeIca.m):

x = load('input.csv');
tm = timeit(@()coshFpDeIca(x'), 3);
fprintf('ICA Deflation (in Matlab) takes %.3f seconds\n', tm);
4

0 回答 0