2

UPD:谢谢,它有效。

我有一个表示直方图的一维向量。它看起来像几个高斯函数的总和: 在此处输入图像描述

curve_fit在 SO 上找到了示例代码,但不知道如何修改它以接收更多高斯元组(mu,sigma)。我听说“curve_fit”只优化一个函数(在这种情况下是一个高斯曲线)。

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def estimate_sigma(hist):
    bin_edges = np.arange(len(hist))
    bin_centres = bin_edges + 0.5

    # Define model function to be used to fit to the data above:
    def gauss(x, *p):
        A, mu, sigma = p
        return A*numpy.exp(-(x-mu)**2/(2.*sigma**2))

    # p0 is the initial guess for the fitting coefficients (A, mu and sigma above)
    p0 = [1., 0., 1.]

    coeff, var_matrix = curve_fit(gauss, bin_centres, hist, p0=p0)

    # Get the fitted curve
    hist_fit = gauss(bin_centres, *coeff)

    plt.plot(bin_centres, hist, label='Test data')
    plt.plot(bin_centres, hist_fit, label='Fitted data')

    print 'Fitted mean = ', coeff[1]
    coeff2 =coeff[2]
    print 'Fitted standard deviation = ', coeff2

    plt.show()

这个函数找到一个高斯曲线,而视觉上有 3 或 4 个: 在此处输入图像描述

拜托,你能建议一些 numpy/scipy 函数来实现1D vector格式的 gmm 表示([m1, sigma1],[m2, sigma2],..,[mN,sigmaN])吗?

4

1 回答 1

0

正如tBuLi建议的那样,我将额外的高斯曲线系数传递给gauss以及curve_fit。现在拟合曲线看起来是这样的: 在此处输入图像描述

更新代码:

def estimate_sigma(hist):
    bin_edges = np.arange(len(hist))
    bin_centres = bin_edges + 0.5

    # Define model function to be used to fit to the data above:
    def gauss(x, *gparams):
        g_count = len(gparams)/3
        def gauss_impl(x, A, mu, sigma):
            return A*numpy.exp(-(x-mu)**2/(2.*sigma**2))
        res = np.zeros(len(x))
        for gi in range(g_count):
            res += gauss_impl(x, gparams[gi*3], gparams[gi*3+1], gparams[gi*3+2])
        return res

    # p0 is the initial guess for the fitting coefficients (A, mu and sigma above)
    curves_count = 4
    p0 = np.tile([1., 0., 1.], curves_count)

    coeff, var_matrix = curve_fit(gauss, bin_centres, hist, p0=p0)

    # Get the fitted curve
    hist_fit = gauss(bin_centres, *coeff)

    plt.plot(bin_centres, hist, label='Test data')
    plt.plot(bin_centres, hist_fit, label='Fitted data')

    # Finally, lets get the fitting parameters, i.e. the mean and standard deviation:
    print coeff

    plt.show()
于 2016-12-14T18:58:17.600 回答