4

这个问题与 Overriding other __rmul__ with your class's __mul__中的问题很接近,但我的印象是,这是一个更普遍的问题,而不是数字数据。也没有回答,我真的不想使用矩阵乘法@进行此操作。因此,问题。

我确实有一个接受标量和数值数组乘法的对象。像往常一样,左乘法工作正常,因为它使用了myobj()方法,但在右乘法中,NumPy 使用广播规则并给出元素级结果dtype=object

这也具有无法检查数组大小是否兼容的副作用。

因此,问题是

有没有办法强制 numpy 数组查找__rmul__()另一个对象的,而不是广播和执行元素__mul__()

在我的特定情况下,对象是 MIMO(多输入、多输出)传递函数矩阵(或滤波器系数矩阵,如果您愿意),因此矩阵乘法在线性系统的加法和乘法方面具有特殊含义。因此,在每个条目中都有 SISO 系统。

import numpy as np

class myobj():
    def __init__(self):
        pass

    def __mul__(self, other):
        if isinstance(other, type(np.array([0.]))):
            if other.size == 1:
                print('Scalar multiplication')
            else:
                print('Multiplication of arrays')

    def __rmul__(self, other):
        if isinstance(other, type(np.array([0.]))):
            if other.size == 1:
                print('Scalar multiplication')
            else:
                print('Multiplication of arrays')

A = myobj()
a = np.array([[[1+1j]]])  # some generic scalar
B = np.random.rand(3, 3)

使用这些定义,以下命令会显示不希望的行为。

In [123]: A*a
Scalar multiplication

In [124]: a*A
Out[124]: array([[[None]]], dtype=object)

In [125]: B*A
Out[125]: 
array([[None, None, None],
       [None, None, None],
       [None, None, None]], dtype=object)

In [126]: A*B
Multiplication of arrays

In [127]: 5 * A

In [128]: A.__rmul__(B)  # This is the desired behavior for B*A
Multiplication of arrays
4

2 回答 2

2

默认情况下,NumPy 假定未知对象(不是从 ndarray 继承)是标量,它需要对任何 NumPy 数组的每个元素进行“矢量化”乘法运算。

要自己控制操作,您需要设置__array_priority__(大多数向后兼容)或__array_ufunc__(仅限 NumPy 1.13+)。例如:

class myworkingobj(myobj):
    __array_priority__ = 1000

A = myworkingobj()
B = np.random.rand(3, 3)
B * A  # Multiplication of arrays
于 2017-06-19T15:36:00.293 回答
1

我将尝试演示发生了什么。

In [494]: B=np.random.rand(3,3)

准系统类:

In [497]: class myobj():
     ...:     pass
     ...: 
In [498]: B*myobj()
...

TypeError: unsupported operand type(s) for *: 'float' and 'myobj'

添加一个__mul__

In [500]: class myobj():
     ...:     pass
     ...:     def __mul__(self,other):
     ...:         print('myobj mul')
     ...:         return 12.3
     ...: 
In [501]: B*myobj()
...
TypeError: unsupported operand type(s) for *: 'float' and 'myobj'
In [502]: myobj()*B
myobj mul
Out[502]: 12.3

添加一个rmul

In [515]: class myobj():
     ...:     pass
     ...:     def __mul__(self,other):
     ...:         print('myobj mul',other)
     ...:         return 12.3
     ...:     def __rmul__(self,other):
     ...:         print('myobj rmul',other)
     ...:         return 4.32
     ...: 
In [516]: B*myobj()
myobj rmul 0.792751549595306
myobj rmul 0.5668783619454384
myobj rmul 0.2196204913660168
myobj rmul 0.5474970289273348
myobj rmul 0.2079367474424587
myobj rmul 0.5374571198848628
myobj rmul 0.35748803226628456
myobj rmul 0.41306113085906715
myobj rmul 0.499598995529441
Out[516]: 
array([[4.32, 4.32, 4.32],
       [4.32, 4.32, 4.32],
       [4.32, 4.32, 4.32]], dtype=object)

B*myobj()被赋予B, as B.__mul__(myobj()), 它继续对myobj().__rmul__(i)的每个元素执行B.

myobj()*B翻译为myobj.__mul__(B)

In [517]: myobj()*B
myobj mul [[ 0.79275155  0.56687836  0.21962049]
 [ 0.54749703  0.20793675  0.53745712]
 [ 0.35748803  0.41306113  0.499599  ]]
Out[517]: 12.3

In [518]: myobj().__rmul__(B)
myobj rmul [[ 0.79275155  0.56687836  0.21962049]
 [ 0.54749703  0.20793675  0.53745712]
 [ 0.35748803  0.41306113  0.499599  ]]
Out[518]: 4.32

你不能做任何事情myobj来覆盖 to 的B*myobj()翻译B.__mul__(myobj())。如果您需要更好地控制操作,请使用函数或方法。很难与口译员抗争。

于 2016-11-19T16:37:13.830 回答