First, here is my code:

"""Softmax."""

scores = [3.0, 1.0, 0.2]

import numpy as np

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    num = np.exp(x)
    score_len = len(x)
    y = np.array([0]*score_len)
    sum_n = np.sum(num)
    #print sum_n
    for index in range(1,score_len):
        y[index] = (num[index])/sum_n
    return y

print(softmax(scores))

The error occurs on the following line:

y[index] = (num[index])/sum_n

when I run this code:

# Plot softmax curves
import matplotlib.pyplot as plt
x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

plt.plot(x, softmax(scores).T, linewidth=2)
plt.show()

What exactly is going wrong here?

4 Answers

Simply adding a print statement as a makeshift "debugger" reveals what is happening:

import numpy as np

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    num = np.exp(x)
    score_len = len(x)
    y = np.array([0]*score_len)
    sum_n = np.sum(num)
    #print sum_n
    for index in range(1,score_len):
        print((num[index])/sum_n)
        y[index] = (num[index])/sum_n
    return y

x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])
softmax(scores).T

This prints:

[ 0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504  0.00065504
  0.00065504  0.00065504]

So you are trying to assign this whole array to a single element of another array. That is not allowed!
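
To see the failure in isolation, here is a minimal sketch that has nothing to do with softmax. The names y, row and try_assign are made up for illustration; as far as I know NumPy raises a ValueError here, but the sketch catches TypeError as well in case a different version behaves differently.

```python
import numpy as np

y = np.array([0] * 3)        # 1-D integer array: each slot holds exactly one scalar
row = np.array([0.1, 0.2])   # a whole array, not a scalar


def try_assign():
    """Return the error message, or None if the assignment unexpectedly worked."""
    try:
        y[1] = row           # one slot cannot hold two values
    except (ValueError, TypeError) as exc:
        return str(exc)
    return None


message = try_assign()
print(message)
```

In the question, num[index] is a full row of the 2-D scores array, so the assignment inside the loop is the same situation.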

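There are several ways to make this work. Simply changing

y = np.array([0]*score_len)

to a multidimensional array will work:

y = np.zeros(num.shape)  # num = np.exp(x) is already an ndarray, so it has a .shape

That should fix the error, but I am not sure whether that is what you intended.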
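
For reference, a rough sketch of the whole function with that change applied. Two things in it are my own assumptions about what you want rather than anything stated in the question: the loop starts at 0 (range(1, score_len) leaves the first row at zero), and the sum is taken per column with axis=0 so that each column adds up to 1. The name softmax_loop is made up.

```python
import numpy as np


def softmax_loop(x):
    """Row-by-row softmax over axis 0; accepts a list, a 1-D or a 2-D array."""
    num = np.exp(np.asarray(x, dtype=float))
    y = np.zeros(num.shape)               # float array with the same shape as the input
    sum_n = np.sum(num, axis=0)           # one total per column (a scalar for 1-D input)
    for index in range(num.shape[0]):     # start at 0 so the first row is not skipped
        y[index] = num[index] / sum_n
    return y


x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

print(softmax_loop([3.0, 1.0, 0.2]))
print(softmax_loop(scores).shape)
```

With a float result array the values are also no longer truncated to integers, which is what happens when y is built from [0]*score_len.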


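Edit:

It seems you do not want multidimensional input, so you only need to change:

scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

to

scores = np.hstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

Checking the shape of these arrays by printing scores.shape really helps you find this kind of error yourself. vstack stacks the three 1-D arrays as rows, giving a 2-D array of shape (3, 80), whereas hstack joins them end to end into a single 1-D array of shape (240,) (which is what you want).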
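
A quick way to check this yourself, using the same three arrays as in the question (the variable names rows, stacked_v and stacked_h are only for illustration):

```python
import numpy as np

x = np.arange(-2.0, 6.0, 0.1)
rows = [x, np.ones_like(x), 0.2 * np.ones_like(x)]

stacked_v = np.vstack(rows)   # three rows of len(x) values each: 2-D
stacked_h = np.hstack(rows)   # the three arrays joined end to end: 1-D

print(stacked_v.shape, stacked_h.shape)
```

Indexing stacked_v[1] gives a whole row, while stacked_h[1] gives a single number, which is the difference that matters for the assignment in the loop.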

Answered 2016-02-29T16:43:00.457

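This is a poor way to initialize an array:

y = np.array([0]*score_len)

It is better to do something like

y = np.zeros((n,m))

where n, m are the two dimensions of the final result. I am assuming from your other question that you want y to be 2d (after all, you apply .T to it afterwards).

Pay attention to the shape of the scores passed to the function. And when iterating, include the : in the index. It is optional, but it helps you keep track of the dimensions yourself:

y[index,:] = (num[index,:])/sum_n

In short: focus on understanding how to work with multidimensional arrays, that is, how to create them, how to index them, how to use them without iteration, and how to iterate correctly when you need to.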
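
As an illustration of the "without iteration" part, here is a minimal sketch of a loop-free version. It assumes, which the question does not state explicitly, that the scores are laid out along axis 0 so that each column of a 2-D input is one set of scores; softmax_vectorized is a made-up name.

```python
import numpy as np


def softmax_vectorized(x):
    """Softmax over axis 0 with no explicit loop; works for 1-D and 2-D input."""
    num = np.exp(np.asarray(x, dtype=float))
    # For 2-D input the per-column sums broadcast across the rows of num.
    return num / np.sum(num, axis=0)


x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

print(softmax_vectorized([3.0, 1.0, 0.2]))
print(softmax_vectorized(scores).shape)
```

The result keeps the shape of the input, so softmax_vectorized(scores).T can be passed to plt.plot as in the question.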

Answered 2016-02-29T20:33:39.427

This should work cleanly and quickly:

scores = [3.0, 1.0, 0.2]

import numpy as np


def softmax(x):

    num = np.exp(x)
    score_len = len(x)

    y = np.zeros(score_len, object) # or => np.asarray([None]*score_len)
    sum_n = np.sum(num)

    for i in range(score_len):
        y[i] = num[i] / sum_n

    return y


print(softmax(scores))

x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

printout = softmax(scores).T

print(printout)

Output:

[0.8360188027814407 0.11314284146556011 0.050838355752999158]

[ array([  3.26123038e-05,   3.60421698e-05,   3.98327578e-05,
         4.40220056e-05,   4.86518403e-05,   5.37685990e-05,
         5.94234919e-05,   6.56731151e-05,   7.25800169e-05,
         8.02133239e-05,   8.86494329e-05,   9.79727751e-05,
         1.08276662e-04,   1.19664218e-04,   1.32249413e-04,
         1.46158206e-04,   1.61529798e-04,   1.78518035e-04,
         1.97292941e-04,   2.18042421e-04,   2.40974142e-04,
         2.66317614e-04,   2.94326482e-04,   3.25281069e-04,
         3.59491177e-04,   3.97299194e-04,   4.39083515e-04,
         4.85262332e-04,   5.36297817e-04,   5.92700751e-04,
         6.55035633e-04,   7.23926331e-04,   8.00062328e-04,
         8.84205618e-04,   9.77198335e-04,   1.07997118e-03,
         1.19355274e-03,   1.31907978e-03,   1.45780861e-03,
         1.61112768e-03,   1.78057146e-03,   1.96783579e-03,
         2.17479489e-03,   2.40352006e-03,   2.65630048e-03,
         2.93566604e-03,   3.24441273e-03,   3.58563059e-03,
         3.96273465e-03,   4.37949910e-03,   4.84009504e-03,
         5.34913227e-03,   5.91170543e-03,   6.53344491e-03,
         7.22057331e-03,   7.97996764e-03,   8.81922816e-03,
         9.74675448e-03,   1.07718296e-02,   1.19047128e-02,
         1.31567424e-02,   1.45404491e-02,   1.60696814e-02,
         1.77597446e-02,   1.96275532e-02,   2.16918010e-02,
         2.39731477e-02,   2.64944256e-02,   2.92808687e-02,
         3.23603645e-02,   3.57637337e-02,   3.95250385e-02,
         4.36819230e-02,   4.82759910e-02,   5.33532213e-02,
         5.89644285e-02,   6.51657716e-02,   7.20193157e-02,
         7.95936532e-02,   8.79645908e-02])
 array([ 0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504,
        0.00065504,  0.00065504,  0.00065504,  0.00065504,  0.00065504])
 array([ 0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433,
        0.00029433,  0.00029433,  0.00029433,  0.00029433,  0.00029433])]
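
One thing to be aware of with this approach, shown as a small sketch: for the 2-D input the result is a 1-D object array holding three separate arrays, so .T has no effect on it. The helper name softmax_object and the np.stack conversion are my additions, not part of the code above.

```python
import numpy as np


def softmax_object(x):
    """Same approach as above: an object array with one entry per row of x."""
    num = np.exp(x)
    y = np.zeros(len(x), object)
    sum_n = np.sum(num)                 # total over all elements, not per column
    for i in range(len(x)):
        y[i] = num[i] / sum_n
    return y


x = np.arange(-2.0, 6.0, 0.1)
scores = np.vstack([x, np.ones_like(x), 0.2 * np.ones_like(x)])

result = softmax_object(scores)
print(result.dtype, result.shape)       # an object array with one slot per row

dense = np.stack(list(result))          # convert to a regular 2-D float array
print(dense.dtype, dense.shape)
```

Because sum_n is the sum over the whole array, the values add up to 1 over all entries together rather than within each column, which is why the numbers in the output above are so small.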
Answered 2017-12-21T05:51:58.993

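Inconsistencies in how an array is constructed can also cause this kind of problem, for example

[[1,2,3,4], [2,3], [1],[1,2,3,4]]

This is an example of a bad array: its rows have different lengths.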
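
A small sketch of what happens with such a list. The dtype=float argument and the helper build_float_array are my own additions to make the behaviour explicit; without a dtype the result depends on the NumPy version.

```python
import numpy as np

ragged = [[1, 2, 3, 4], [2, 3], [1], [1, 2, 3, 4]]   # rows of different lengths
regular = [[1, 2], [3, 4]]                           # every row has the same length


def build_float_array(rows):
    """Return the float array, or the ValueError NumPy raised while building it."""
    try:
        return np.array(rows, dtype=float)
    except ValueError as exc:
        return exc


outcome = build_float_array(ragged)
print(type(outcome).__name__)
print(build_float_array(regular).shape)
```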

Answered 2018-06-22T10:08:22.197