0

我正在尝试使用 scipy(差分进化算法)的优化模块来解决优化问题。在最简单的情况下,我想将两个函数拟合到实验值。在较低 x 值范围内使用函数 1,在较高 x 值范围内使用函数 2。两个函数之间的切换是在两个函数相交的 x 值处完成的。现在的问题是,根据使用的函数的参数,可能没有交集,无法计算残差。我使用 scipy.optimize 中的 brentq 来计算交点。如果我捕捉到一个 ValueError,这意味着没有交集,我使用惩罚方法返回一个非常高的残差。现在的问题是,在许多情况下,差分进化会陷入局部最小值。我已经尝试过人口等其他选项,但我认为主要问题是使用的惩罚方法。对于两条曲线没有交点的情况,除了罚分法还有其他方法吗?我试图创建一个简单的案例来描述问题。因为我在这个例子中只拟合直线,所以在这种情况下一切正常,但对于更复杂的函数,有时我需要数百次尝试才能获得全局最小值。

如果有任何帮助,我将不胜感激。

亲切的问候克里斯蒂安

import numpy as np
from scipy.optimize import differential_evolution, brentq
from matplotlib import pyplot as plt

def intersect_fcn(x_intersect, K):
    '''used to calculate intersection with brentq'''
    m1 = K[0]
    c1 = K[1]
    
    m2 = K[2]
    c2 = K[3]
    
    y1 = m1*x_intersect + c1
    y2 = m2*x_intersect + c2
    
    return y1 - y2


def res_fcn(K, *data):
    '''returns residuals for curve fit'''
    m1 = K[0]
    c1 = K[1]
    
    m2 = K[2]
    c2 = K[3]
    
    x, y = data
    
    try: #to find an intersection
        intersect = brentq(intersect_fcn, 0, 10, args=(K))
    except ValueError: #did not find a intersection
        # penalty strategy:
        res = 100000
        return  res
    
    # if an intersection is found calculate the residuum normally:
    y1 = m1*x + c1
    y2 = m2*x + c2

    x1_bool = (x <= intersect)
    x2_bool = (x > intersect)
    y_calc = np.empty_like(x)
    
    y_calc[x1_bool] = y1[x1_bool]
    y_calc[x2_bool] = y2[x2_bool]
    
    res = y - y_calc
    res = res**2
    res = sum(res)
    return res


# generate a sample with two straight lines:
x1 = np.linspace(0, 4, 5)
x2 = np.linspace(5, 10, 5)
x = x1
x = np.append(x, x2)

y = -0.2*x1 + 7
y = np.append(y, (x2  + 1))

# add some noise
y = y + 0.1 * np.random.ranf(y.size)
args = (x, y)

bounds = [(-10, 10), (-10, 10), (-10, 10), (-10, 10)]
Results = differential_evolution(res_fcn,
                                 bounds = bounds,
                                 args=args)
y1_res = Results.x[0]*x1 + Results.x[1]
y2_res = Results.x[2]*x2 + Results.x[3]

y_res = np.append(y1_res, y2_res)

# plot results:
fig = plt.figure(figsize=(9, 6))
left, bottom, width, height = 0.15, 0.15, 0.7, 0.7
ax = fig.add_axes((left, bottom, width, height))
ax.set_xlabel('x', labelpad=14, fontsize=16)
ax.set_ylabel('y', labelpad=14, fontsize=16)
#ax.set_title('Isothermenfeld {}'.format(material), fontsize=18, pad=16)
ax.set_title('global optimization', fontsize=18, pad=16)                     
ax.tick_params(axis='both', which='major', labelsize=12)
ax.grid(color='grey', axis='both', which='major', linestyle=':')  

ax.scatter(x,y)
ax.plot(x, y_res, color = 'k')
4

0 回答 0