0

我正在研究GANs(而且我是python的初学者),我在之前的练习中发现了这部分代码我不明白。具体来说,我不明白为什么使用第 9 行的布尔值(Xk = X[Y == k]),原因我写在下面

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  1. 该布尔值根据 Y == k 的真实性返回 0 或 1,因此 Xk 始终是 X 列表的第一个或第二个值。Y 没有找到它的用处。
  2. 在第 10 行中,len(Xk) 始终为 1,为什么它使用该参数而不是单个 1?
  3. 下一行的均值和协方差每次只计算一个值。

我觉得我没有理解一些非常基本的东西。

4

2 回答 2

2

您应该考虑到它们X, Y, k是 NumPy 数组,而不是标量,并且某些运算符对它们来说是重载的。特别是==基于布尔的索引。==将是逐元素比较,而不是整个数组比较。

看看它怎么运作:

In [9]: Y = np.array([0,1,2])                                                                                        
In [10]: k = np.array([0,1,3])                                                                                       
In [11]: Y==k                                                                                                        

Out[11]: array([ True,  True, False])

所以,结果==是一个布尔数组。

In [12]: X=np.array([0,2,4])                                                                                         
In [13]: X[Y==k]                                                                                                     

Out[13]: array([0, 2])

结果是一个数组,其中的元素从X条件为True

因此len(Xk)将是 和 之间匹配元素的X数量k

于 2019-01-24T08:33:54.903 回答
0

谢谢,阿尔特姆,

你说的对。我在另一个频道找到了另一个答案,这里是:

它是一个 Numpy 数组 - 它是 NumPy 数组的一个特殊功能,称为布尔索引,可让您仅过滤掉数组中过滤器返回 True 的值:

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html?fbclid=IwAR3sGlgSwhv3i7IETsIxp4ROu9oZvNaaaBxZS01DrM5ShjWWRz22ShP2rIg#boolean-or-mask-index-arrays

将 numpy 导入为 np

a = np.array([1, 2, 3, 4, 5]) 过滤器 = a > 3

打印(过滤器)

[假,假,假,真,真]

打印(一个[过滤器])

[4, 5]

于 2019-01-24T08:56:55.427 回答