1

我尝试使用lambdify 加速对MutableDenseMatrix 的评估。它适用于模块“numpy”。'Numexpr' 应该更快(因为我需要评估来解决一个大的优化问题)。

我正在尝试做的一个较小的例子是

from sympy import symbols, cos, Matrix, lambdify

a11, a12, a21, a22, b11, b12, b21, b22, u = symbols("a11 a12 a21 a22 b11 b12 b21 b22 u")
A = Matrix([[a11, a12], [a21, a22]])
B = Matrix([[b11, b12], [b21, b22]])
expr = A * (B ** 2) * cos(u) + A ** (-3 / 2)
f = lambdify((A, B), expr, modules='numexpr')

它引发了错误

TypeError: numexpr cannot be used with ImmutableDenseMatrix

有没有办法为 DenseMatrices 使用lambdify?或者另一个想法如何加快评估?

提前致谢!

4

1 回答 1

1

使用 numexpr 的一种可能解决方案是自行评估每个矩阵表达式。以下代码应输出一个 python 函数,该函数使用 Numexpr 评估所有矩阵表达式。

带矩阵的 Numexpr

import numpy as np
import sympy

def lambdify_numexpr(args,expr,expr_name):
    from sympy.printing.lambdarepr import NumExprPrinter as Printer
    printer = Printer({'fully_qualified_modules': False, 'inline': True,'allow_unknown_functions': False})

    s=""
    s+="import numexpr as ne\n"
    s+="from numpy import *\n"
    s+="\n"

    #get arg_names
    arg_names=[]
    arg_names_str=""
    for i in range(len(args)):
        name=[ k for k,v in globals().items() if v is args[i]][0]
        arg_names_str+=name
        arg_names.append(name)

        if i< len(args)-1:
            arg_names_str+=","

    #Write header
    s+="def "+expr_name+"("+arg_names_str+"):\n"

    #unroll array
    for ii in range(len(args)):
        arg=args[ii]
        if arg.is_Matrix:
            for i in range(arg.shape[0]):
                for j in range(arg.shape[1]):
                    s+="    "+ str(arg[i,j])+" = " + arg_names[ii]+"["+str(i)+","+str(j)+"]\n"

    s+="    \n"
    #If the expr is a matrix
    if expr.is_Matrix:
        #write expressions
        for i in range(len(expr)):
            s+="    "+ "res_"+str(i)+" = ne."+printer.doprint(expr[i])+"\n"
            s+="    \n"

        res_counter=0
        #write array
        s+="    return concatenate(("
        for i in range(expr.shape[0]):
            s+="("
            for j in range(expr.shape[1]):
                s+="res_"+str(res_counter)+","
                res_counter+=1
            s+="),"
        s+="))\n"

    #If the expr is not a matrix
    else:
        s+="    "+ "return ne."+printer.doprint(expr)+"\n"
    return s
于 2020-09-24T17:29:10.633 回答