1

我试图在这里创建一个简单的矩阵,对我批次中的每个样本重复。

这是矩阵:

balanceMatrix = np.array([[[5,10,10],[1,1,1],[1,1,1]]])
print(balanceMatrix.shape)

balanceMatrix = K.constant(balanceMatrix)
print(K.shape(balanceMatrix).eval())

到目前为止,一切都很好,我有预期的矩阵形状(1,3,3)。现在我希望对批次中的每个样本(比如 60000 个样本)重复它。从 keras文档中,我应该做的就是:

balanceMatrix = K.repeat_elements(balanceMatrix, 60000,axis=0)
print(K.shape(balanceMatrix).eval())

但这会引发以下错误,我不能简单地理解:

IndexError                                Traceback (most recent call last)
<ipython-input-28-4356baf13de8> in <module>()
     20 balanceMatrix = K.constant(balanceMatrix)
     21 print(K.shape(balanceMatrix).eval())
---> 22 balanceMatrix = K.repeat_elements(balanceMatrix, 60000,axis=0)
     23 print(K.shape(balanceMatrix).eval())
     24 

c:\users\ut65\appdata\local\programs\python\python35\lib\site-packages\keras\backend\theano_backend.py in repeat_elements(x, rep, axis)
    743     if hasattr(x, '_keras_shape'):
    744         y._keras_shape = list(x._keras_shape)
--> 745         repeat_dim = x._keras_shape[axis]
    746         if repeat_dim is not None:
    747                 y._keras_shape[axis] = repeat_dim * rep

IndexError: tuple index out of range

到底是怎么回事??我知道,我可以np.repeat(balanceMatrix,60000,axis=0)先创建 keras 张量,然后再创建 keras 张量,但 keras 选项不应该也能正常工作吗?

4

1 回答 1

2

我相信K.variable在这里会有所帮助:

balanceMatrix = K.variable(value=balanceMatrix)
于 2017-05-09T16:14:41.583 回答