1

我正在尝试用 `scipy.optimize.minimize` 优化一个以数组(矩阵)作为参数的问题,但在目标函数内部我只使用该数组的一部分元素。

import numpy as np
from scipy.optimize import minimize

# Problem size, plus the (row, column) index arrays of the upper triangle
# (diagonal included) that the restricted objective compares on.
n = 5
indices = np.triu_indices(n)

# Unseeded draws, in this order: the target matrix first, then the guess.
X_true, X_guess = [np.random.normal(size=(n, n)) for _ in range(2)]

def mean_square_error(X):
    """Return the mean squared error between ``X`` and the global ``X_true``.

    Both operands are flattened before subtracting, so ``X`` can be the
    original (n, n) matrix or any array with the same number of elements,
    such as the 1-D vector ``minimize`` evaluates.
    """
    diff = X.ravel() - X_true.ravel()
    return np.mean(diff ** 2)

def mean_square_error_over_indices(X):
    """Return the MSE between ``X`` and ``X_true`` at the global ``indices``.

    ``X`` must be at least 2-D, because it is indexed with the two-array
    tuple ``indices``. A 1-D array, such as the vector ``minimize`` passes
    to the objective, raises ``IndexError``.
    """
    residual = X[indices].ravel() - X_true[indices].ravel()
    return (residual ** 2).mean()

# works fine: called directly on the 2-D (n, n) guess
print(mean_square_error(X_guess)) 

# works fine: X_guess is 2-D here, so X[indices] is a valid index
print(mean_square_error_over_indices(X_guess)) 

# works fine (flatten is necessary inside the objective function):
# minimize evaluates the objective on a 1-D vector, and flatten() makes the
# comparison with the 2-D X_true independent of shape.
# NOTE(review): newer SciPy releases deprecate or reject a non-1-D x0;
# passing X_guess.ravel() is the portable form. Confirm for your version.
print(minimize(mean_square_error, X_guess).x)

# IndexError: inside minimize the objective receives a 1-D vector.
# Indexing that vector with the two-array tuple `indices` fails with
# "too many indices for array" (see the traceback).
print(minimize(mean_square_error_over_indices, X_guess).x)

回溯(Traceback):

IndexError                                Traceback (most recent call last)
<ipython-input-1-08d40604e22a> in <module>
     20 print(minimize(mean_square_error, X_guess).x) # works fine
     21 
---> 22 print(minimize(mean_square_error_over_indices, X_guess).x) # error

C:\Anaconda\lib\site-packages\scipy\optimize\_minimize.py in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    593         return _minimize_cg(fun, x0, args, jac, callback, **options)
    594     elif meth == 'bfgs':
--> 595         return _minimize_bfgs(fun, x0, args, jac, callback, **options)
    596     elif meth == 'newton-cg':
    597         return _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in _minimize_bfgs(fun, x0, args, jac, callback, gtol, norm, eps, maxiter, disp, return_all, **unknown_options)
    968     else:
    969         grad_calls, myfprime = wrap_function(fprime, args)
--> 970     gfk = myfprime(x0)
    971     k = 0
    972     N = len(x0)

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in function_wrapper(*wrapper_args)
    298     def function_wrapper(*wrapper_args):
    299         ncalls[0] += 1
--> 300         return function(*(wrapper_args + args))
    301 
    302     return ncalls, function_wrapper

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in approx_fprime(xk, f, epsilon, *args)
    728 
    729     """
--> 730     return _approx_fprime_helper(xk, f, epsilon, args=args)
    731 
    732 

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in _approx_fprime_helper(xk, f, epsilon, args, f0)
    662     """
    663     if f0 is None:
--> 664         f0 = f(*((xk,) + args))
    665     grad = numpy.zeros((len(xk),), float)
    666     ei = numpy.zeros((len(xk),), float)

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in function_wrapper(*wrapper_args)
    298     def function_wrapper(*wrapper_args):
    299         ncalls[0] += 1
--> 300         return function(*(wrapper_args + args))
    301 
    302     return ncalls, function_wrapper

<ipython-input-1-08d40604e22a> in mean_square_error_over_indices(X)
     11 
     12 def mean_square_error_over_indices(X):
---> 13     return ((X[indices].flatten() - X_true[indices].flatten()) ** 2).mean()
     14 
     15 

IndexError: too many indices for array
4

1 个回答

2

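根据文档,`scipy.optimize.minimize` 的初始猜测 `x0` 应为形状为 `(n,)` 的一维数组,而目标函数在优化过程中收到的也是一维向量。这正是 `X[indices]` 这种二维索引会引发 `IndexError: too many indices for array` 的原因。因此您在目标函数中使用 `flatten()` 是正确的,但也应该对传递给 `minimize()` 的初始猜测使用它。注意:下面的第二个目标函数只对上三角元素进行优化;如果仍想优化整个 n×n 矩阵,可以在目标函数内用 `x.reshape(n, n)[indices]` 恢复形状后再取子集。这是我解决您问题的建议: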

import numpy as np
from scipy.optimize import minimize

# Problem setup. The draws are unseeded, so the printed numbers change
# from run to run.
n = 5
x_true = np.random.normal(size=(n, n))
x_guess = np.random.normal(size=(n, n))
indices = np.triu_indices(n)

# minimize() works on 1-D parameter vectors, so hand it 1-D initial guesses.
# x_guess[indices] is already 1-D (one entry per upper-triangular position),
# so it needs no extra flatten().
guess_x0 = x_guess.flatten()
guess_indices_x0 = x_guess[indices]

# Flatten the targets once here instead of on every objective evaluation.
x_true_flat = x_true.flatten()
x_true_upper = x_true[indices]


def mse(x):
    """Return the mean squared error between the 1-D vector x and x_true.

    x holds all n * n entries in x_true.flatten() (row-major) order.
    """
    return ((x - x_true_flat) ** 2).mean()


def mse_over_indices(x):
    """Return the MSE between x and the upper-triangular entries of x_true.

    x holds only the len(indices[0]) upper-triangular values, in the order
    produced by x_true[indices], so only those values are optimized. To keep
    optimizing the full (n, n) matrix instead, reshape inside the objective
    and select afterwards: x.reshape(n, n)[indices].
    """
    return ((x - x_true_upper) ** 2).mean()


# Both objectives evaluate cleanly on the flattened guesses.
print(f"MSE: {mse(guess_x0):5f}")
print(f"MSE for indices: {mse_over_indices(guess_indices_x0):5f}")

# Both optimizations run: each x0 is 1-D and each objective expects a
# 1-D vector, so no reshaping or flattening is needed inside them.
print("Result 1:", minimize(mse, guess_x0).x)
print("Result 2:", minimize(mse_over_indices, guess_indices_x0).x)

输出:

MSE: 2.763674
MSE for indices: 3.192139
Result 1: [-1.2828193   0.49468516 -0.99500157 -0.47284983  1.6380719  -0.33051017
  0.13769163 -0.23920633 -0.87430572  0.63945803  1.38327467  0.8484247
  0.31888506 -1.15764468  1.06891773 -0.28372002  1.34104286  1.21024251
 -0.11020374  1.37024001  1.08940389  1.82391261  0.32469148  0.64567877
  0.54364199]
Result 2: [-1.28281964  0.49468503 -0.99500147 -0.47284976  1.63807209  0.13769154
 -0.23920624 -0.87430606  0.63945812  0.31888521 -1.15764475  1.06891776
 -0.11020373  1.37024006  0.54364213]
回答于 2019-06-08 15:36:01