6

我正在尝试使用行为类似于 NumPy 数组的类来实现自动微分。没有子类numpy.ndarray,但包含两个数组属性。一个用于值,一个用于雅可比矩阵。每个操作都被重载以对值和雅可比进行操作。但是,我无法让 NumPy ufunc(例如,np.log)在我的自定义“数组”上工作。

我创建了以下最小示例来说明问题。Two应该是 NumPy 数组的抗辐射版本,它计算所有内容两次,并确保结果相等。

它必须支持索引、元素对数和长度。就像一个正常的ndarray. 元素级对数在调用 using 时工作正常x.cos(),但在调用 using 时会出现意想不到的结果np.cos(x)

from __future__ import print_function
import numpy as np

class Two(object):
    def __init__(self, val1, val2):
        print("init with", val1, val2)
        assert np.array_equal(val1, val2)
        self.val1 = val1
        self.val2 = val2

    def __getitem__(self, s):
        print("getitem", s, "got", Two(self.val1[s], self.val2[s]))
        return Two(self.val1[s], self.val2[s])

    def __repr__(self):
        return "<<{}, {}>>".format(self.val1, self.val2)

    def log(self):
        print("log", self)
        return Two(np.log(self.val1), np.log(self.val2))

    def __len__(self):
        print("len", self, "=", self.val1.shape[0])
        return self.val1.shape[0]

x = Two(np.array([1,2]).T, np.array([1,2]).T)

正如预期的那样,索引从两个属性返回相关元素:

>>> print("First element in x:", x[0], "\n")
init with [1 2] [1 2]
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
First element in x: <<1, 1>> 

使用以下方法调用时,逐元素对数工作得很好x.cos()

>>> print("--- x.log() ---", x.log(), "\n")
log <<[1 2], [1 2]>>
init with [ 0.  0.69314] [ 0.  0.69314]
--- x.log() --- <<[ 0.  0.69314], [ 0.   0.69314]>> 

但是,np.log(x)没有按预期工作。它意识到对象是有长度的,所以它提取每一项并对每一项取对数,然后返回一个由两个对象组成的数组(dtype=object)。

>>> print("--- np.log(x) with len ---", np.log(x), "\n") # WTF
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
log <<1, 1>>
init with 0.0 0.0
log <<2, 2>>
init with 0.693147 0.693147
--- np.log(x) with len --- [<<0.0, 0.0>> <<0.693147, 0.693147>>]

如果Two没有长度方法,它工作得很好:

>>> del Two.__len__
>>> print("--- np.log(x) without len ---", np.log(x), "\n")
log <<[1 2], [1 2]>>
init with [ 0.          0.69314718] [ 0.   0.693147]
--- np.log(x) without len --- <<[ 0.   0.693147], [ 0.          0.693147]>>

如何创建满足要求的类(getitem、log、len)?我研究了 subclassing ndarray,但这似乎比它的价值更复杂。

另外,我在 NumPy 源代码中找不到 x.__len__访问的位置,所以我也对此感兴趣。

编辑:我将 miniconda2 与 Python 2.7.11 和 NumPy 1.11.0 一起使用。

4

0 回答 0