-1

我以本教程为例来构建我的 caffe 自定义训练功能。在第 15 节有以下代码:

def train():
    niter = 200
    test_interval = 25 
    train_loss = zeros(niter)
    test_acc = zeros(int(np.ceil(niter / test_interval)))

    ### HERE ###
    output = zeros((niter, 8, 10))
    ###      ###

在第 8 行有一个ndarray(输出),这段代码的含义是什么,它是什么意思。是什么意思(niter, 8, 10)。为什么niter,为什么是 8,为什么是 10?我应该根据自己的数据集更改此数组吗?如果是,我应该使用什么尺寸?有人可以解释一下吗?

4

2 回答 2

2

如果您仔细阅读本教程,您会发现它处理数字分类,因此有10 个类。此外,他们使用技巧将 8 个示例拼凑在一起(第 11 节,靠近In [11]:):

# 我们使用一个小技巧来平铺前八张图片

因此是8维。

第 15 节展示了一个跟踪网络进度的示例。它保存每次迭代的输出预测概率。每次迭代有10 个类乘以8个示例,并且有niter迭代要跟踪。所有这些信息都存储在 3Doutput阵列中。

于 2015-11-25T17:39:05.857 回答
1

它看起来像一个调用numpy.zeroswhereshape = (niter, 8, 10)创建一个 200 * 8 * 10 浮点 0 数组。

于 2015-11-25T17:26:10.657 回答