问题标签 [dropout]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
2 回答
2186 浏览

python - Tensorflow:tf.nn.dropout output_keep_prob 实际上是什么?

我试图理解的概念output_keep_prob

因此,如果我的示例是简单的 RNN:

我的困惑是,如果我给出output_keep_prob=0.5的实际上是什么意思?我知道通过添加 dropout 可以减少过度拟合(称为正则化)的可能性。它在训练期间随机关闭神经元的激活,好的,我明白了这一点,但是当我给出时我很困惑

output_keep_prob=0.5我的 no_of_nodes = 500 然后 0.5 意味着它将在每次迭代中随机转动 50% 的节点,或者意味着它将只保留那些概率大于或等于 0.5 的连接

或者

我试图通过这个stackoverflow 答案来理解这个概念,但实际上 0.5 的含义也有同样的困惑?它应该在每次迭代中丢弃 50% 的节点,或者只保留那些概率大于或等于 0.5 的节点

如果答案是第二个keep only those nodes which have probability more or equal to 0.5

那么这意味着假设我给了 500 个节点单元,并且只有 30 个节点有 0.5 个概率,所以它将关闭其余 470 个节点,并且只使用 30 个节点进行传入和传出连接?

因为 这个答案说:

假设层中有 10 个单位并将 keep_prob 设置为 0.1,那么 10 个中随机选择的 9 个单位的激活将设置为 0,其余的将按 10 倍缩放。我认为更精确描述是你只保留了 10% 的节点的激活。

而另一方面,@mrry 的回答说:

这意味着层之间的每个连接(在这种情况下是在最后一个密集连接层和读出层之间)在训练时将仅以 0.5 的概率使用。

任何人都可以清楚地解释哪个是正确的以及这个值在keep_prob中实际代表什么?

0 投票
1 回答
417 浏览

python - Tensorflow:每个图像的验证预测结果相同

我有以下问题。

我正在尝试在 tensorflow 中训练 3d CNN。我将数据分为三个数据集,训练、验证和测试。

主要问题是,当我在训练 5 个 epoch 后测试验证集时,模型的输出与 5 张图像几乎相同。(这是没有任何softmax的最后一层的输出)

但是,如果我对训练集做同样的事情,我会得到一个传统的预测。

我已经全面检查了数据集,两者都是正确的并且在相同的条件下。

这是我用来构建模型和进行训练的模式:

Cnn3DMRI 类(对象):

我已经尝试过 tf.set_random_seed( 1 ) 但没有看到更正

请问有人知道吗?

非常感谢

22/04/18 编辑:

要分类的数据是二分类问题中 150x150x40 像素的 3d 图像。我总共有 400 张图片,大约是每个班级的一半。我在训练(75%)、验证(10%)和测试(15%)中分离了数据集

Edit2:我简化了一点我的模型。抬头看

还要提到我们只有 2 个班级

我尝试了另一项检查,我只用 20 张图像训练了我的模型。查看是否获得了 0 成本。

125 个 epoch 后的结果:

2018-04-24 23:58:24.992421 epoch loss mean: [4549.9554141853, 1854.6537470817566, 817.4076923541704, 686.8368729054928, 687.7348744268759, 704.946801304817, 483.6952783479355, 260.2293045549304, 272.66821688037817, 116.57515235748815, 97.86094704543848, 90.43152131629176, 132.54018089070996, 69.62595339218387, 57.412255316681694, 79.66184640157735, 70.99515068903565, 55.75798599421978 , 44.14403077028692, 38.901107819750905, 49.75594720244408, 52.6321079954505, 37.70595762133598, 42.07099115010351, 29.01994925737381, 28.365123450756073, 31.93120799213648, 43.9855432882905, 33.242121398448944, 36.57513061538339, 28.828659534454346, 29.847569406032562, 24.078316539525986, 31.630925316363573, 30.5430103354156, 26.18060240149498, 32.86780231446028, 25.42889341711998, 29.355055704712868, 26.269534677267075, 24.921810917556286, 27.15281054377556, 27 .343381822109222, 24.293660208582878, 28.212179094552994, 25.07626649737358, 21.650991335511208, 25.7527906447649, 23.42476052045822, 28.350880563259125, 22.57907184958458, 21.601420983672142, 25.28128480911255, 25.550641894340515, 22.444457232952118, 27.660063683986664, 21.863914296031, 25.722180172801018, 24.00674758851528, 21.46472266316414, 26.599679857492447, 23.52132275700569, 26.1786640137434, 24.842691332101822, 25.263965144753456, 22.730938494205475, 22.787407517433167, 23.58866274356842, 25.351682364940643, 23.85272353887558, 23.884423837065697, 24.685379207134247, 22.55106496810913, 25.993630707263947, 21.967322662472725, 22.651918083429337, 21.91003155708313, 23.782021015882492, 21.567724645137787, 22.130879193544388, 21.33636975288391, 25.624440014362335, 23.26347705721855, 22.370914071798325, 22.614411562681198, 24.962509214878082, 22.121410965919495, 20.644148647785187, 24.472172617912292, 21.622991144657135, 21.719978988170624, 21.72349101305008, 21.729621797800064, 22.090826153755188, 21.44688707590103, 22.34817299246788, 22.93226248025894, 22.63547444343567, 22.1306095123291, 22.16277289390564, 22.83771103620529, 24.171751350164413, 22.025538682937622, 21.339059710502625, 22.169043481349945, 24.614955246448517, 22.83159503340721, 21.43451902270317, 21.54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623, 23.04211688041687, 24.00969059765339, 21.8588485121727, 22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch: 125622991144657135, 21.719978988170624, 21.72349101305008, 21.729621797800064, 22.090826153755188, 21.44688707590103, 22.34817299246788, 22.93226248025894, 22.63547444343567, 22.1306095123291, 22.16277289390564, 22.83771103620529, 24.171751350164413, 22.025538682937622, 21.339059710502625, 22.169043481349945, 24.614955246448517, 22.83159503340721, 21.43451902270317, 21.54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623, 23.04211688041687, 24.00969059765339, 21.8588485121727, 22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch: 125622991144657135, 21.719978988170624, 21.72349101305008, 21.729621797800064, 22.090826153755188, 21.44688707590103, 22.34817299246788, 22.93226248025894, 22.63547444343567, 22.1306095123291, 22.16277289390564, 22.83771103620529, 24.171751350164413, 22.025538682937622, 21.339059710502625, 22.169043481349945, 24.614955246448517, 22.83159503340721, 21.43451902270317, 21.54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623, 23.04211688041687, 24.00969059765339, 21.8588485121727, 22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch: 12534817299246788, 22.93226248025894, 22.63547444343567, 22.1306095123291, 22.16277289390564, 22.83771103620529, 24.171751350164413, 22.025538682937622, 21.339059710502625, 22.169043481349945, 24.614955246448517, 22.83159503340721, 21.43451902270317, 21.54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623、23.04211688041687、24.00969059765339、21.8588485121727、22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch34817299246788, 22.93226248025894, 22.63547444343567, 22.1306095123291, 22.16277289390564, 22.83771103620529, 24.171751350164413, 22.025538682937622, 21.339059710502625, 22.169043481349945, 24.614955246448517, 22.83159503340721, 21.43451902270317, 21.54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623、23.04211688041687、24.00969059765339、21.8588485121727、22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch54544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623, 23.04211688041687, 24.00969059765339, 21.8588485121727, 22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch: 12554544973373413, 22.889380514621735, 24.168621599674225, 21.947510302066803, 22.30243694782257, 22.381454586982727, 22.50485634803772, 22.61657750606537, 22.288170099258423, 21.30070123076439, 22.489792048931122, 21.885000944137573, 21.343613982200623, 23.04211688041687, 24.00969059765339, 21.8588485121727, 22.199619591236115] 2018-04-24 23:58:24.992694 n_epoch: 125

每层的打印输出:

conv1:[[[[[0.0981627107 0.100793235 0.0934509188]]]]...]

最大值1:[[[[[0.102978 0.107030481 0.0977560952]]]]]...]

max2: [[[[[0 0 0.00116439909]]]]...]

重塑:[[0 0 0.00116439909]...]

fc:[[0.01167579 0.182256863 0.107154548]...]

fc2:[[0.773868561 0.364259362 0]...]

输出:[[0.16590938 -0.255491495][0.16590938]...]

conv1:[[[[[0.0981602222 0.100800745 0.0934513509]]]]...]

最大值1:[[[[[0.102975294 0.107038349 0.0977560282]]]]...]

max2: [[[[[0 0 0.000874094665]]]]...]

重塑:[[0 0 0.000874094665]...]

fc:[[0.0117974132 0.182980478 0.106876813]...]

fc2:[[0.774896204 0.36372292 0]...]

输出:[[0.129838273 -0.210624188][0.129838273]...]

不应该是 125 个 epoch 就可以过拟合 60 个样本吗?

对正在发生的事情有任何想法吗?

0 投票
1 回答
4365 浏览

deep-learning - Keras LSTM:辍学与经常性辍学

我意识到这篇文章提出了与此类似的问题

但我只是想要一些澄清,最好是指向某种说明差异的 Keras 文档的链接。

在我看来,dropout在神经元之间起作用。并recurrent_dropout在时间步长之间工作每个神经元。但是,我对此毫无根据。

Keras 网站上的文档根本没有帮助。

0 投票
0 回答
624 浏览

tensorflow - 将 DropoutWrapper 和 ResidualWrapper 与variational_recurrent=True 结合使用

我正在尝试创建一个用 DropoutWrapper 和 ResidualWrapper 包裹的 LSTM 单元的 MultiRNNCell。对于使用variational_recurrent=True,我们必须向DropoutWrapper 提供input_size 参数。我无法弄清楚应该将什么 input_size 传递给每个 LSTM 层,因为 ResidualWrapper 还添加了跳过连接以增加每一层的输入。

我正在使用以下实用程序函数来创建一个 LSTM 层:

以下代码用于创建完整的单元格:

对于第一层和后续 LSTM 层,应该将哪些值传递给 input_size?

0 投票
1 回答
625 浏览

tensorflow - multiRNNCell 中哪种正则化使用 L2 正则化或 dropout?

我一直在从事与用于时间序列预测的序列到序列自动编码器相关的项目。所以,我tf.contrib.rnn.MultiRNNCell在编码器和解码器中使用过。我很困惑使用哪种策略来规范我的 seq2seq 模型。我应该在损失中使用 L2 正则化还是tf.contrib.rnn.DropoutWrapper在 multiRNNCell 中使用 DropOutWrapper ()?或者我可以同时使用这两种策略...... L2 用于权重和偏差(投影层)以及 multiRNNCell 中的单元格之间的 DropOutWrapper 吗?提前致谢 :)

0 投票
1 回答
1113 浏览

r - 使用 R 的 Keras 中的蒙特卡洛(MC)辍学

如何按照 YARIN GAL 的建议在卷积神经网络中使用 Keras 实现 Monte Carlo dropout 以估计预测不确定性?我正在使用 R。R -Code 在这里

我正在小批量拟合模型,并希望使用 Monte Carlo dropout 小批量评估模型。在 Keras 文档中找不到任何提示。顺便说一句,我使用 flag training=TRUE 训练了我的模型。

谢谢

0 投票
1 回答
3918 浏览

pytorch - 在 pytorch 中实现 word dropout

我想在我的网络中添加单词丢失,以便我可以有足够的训练示例来训练“unk”令牌的嵌入。据我所知,这是标准做法。假设 unk 标记的索引为 0,填充的索引为 1(如果更方便,我们可以切换它们)。

这是一个简单的 CNN 网络,它以我期望的方式实现 word dropout:

不要介意填充 - pytorch 没有在 CNN 中使用非零填充的简单方法,更不用说可训练的非零填充,所以我手动进行。辍学也不允许我使用非零辍学,我想将填充令牌与 unk 令牌分开。我将它保留在我的示例中,因为它是这个问题存在的原因。

这不起作用,因为 dropout 需要浮点张量以便它可以正确缩放它们,而我的输入是不需要缩放的长张量。

在 pytorch 中是否有一种简单的方法可以做到这一点?我本质上想使用对 LongTensor 友好的 dropout(奖励:如果它能让我指定一个不为 0 的 dropout 常数更好,这样我就可以使用零填充)。

0 投票
1 回答
213 浏览

neural-network - 瓶颈层的辍学率

通常使用默认的 dropout rate0.5作为默认值,我也在我的全连接网络中使用它。该建议遵循原始 Dropout 论文(Hinton 等人)的建议。

我的网络由大小完全连接的层组成

[1000, 500, 100, 10, 100, 500, 1000, 20].

我不对最后一层应用 dropout。但我确实将它应用于大小 10 的瓶颈层。这似乎不合理dropout = 0.5。我想很多信息都会丢失。使用 dropout 时如何处理瓶颈层是否有经验法则?是增加瓶颈的大小还是降低辍学率更好?

0 投票
2 回答
17370 浏览

machine-learning - 如何理解 SpatialDropout1D 以及何时使用它?

偶尔我会看到一些模型正在使用SpatialDropout1D而不是Dropout. 例如,在词性标注神经网络中,他们使用:

根据 Keras 的文档,它说:

此版本执行与 Dropout 相同的功能,但它会丢弃整个 1D 特征图而不是单个元素。

但是,我无法理解entrie 1D feature的含义。更具体地说,我无法在quoraSpatialDropout1D中解释的同一模型中进行可视化。有人可以使用与 quora 中相同的模型来解释这个概念吗?

另外,在什么情况下我们将使用SpatialDropout1D而不是Dropout

0 投票
1 回答
3095 浏览

machine-learning - 关于在 RNN (Keras) 中正确使用 dropout

我对如何在 keras 中正确使用带有 RNN 的 dropout 感到困惑,特别是对于 GRU 单元。keras 文档参考了这篇论文(https://arxiv.org/abs/1512.05287),我知道所有时间步都应该使用相同的 dropout 掩码。这是通过 dropout 参数实现的,同时指定 GRU 层本身。我不明白的是:

  1. 为什么互联网上有几个示例,包括 keras 自己的示例 ( https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py ) 和 Andrew Ng 的 Coursera Seq 中的“触发词检测”作业。模型课程,他们在其中显式添加了一个 dropout 层“model.add(Dropout(0.5))”,据我了解,它将为每个时间步添加一个不同的掩码。

  2. 上面提到的论文表明这样做是不合适的,并且由于这种丢失噪声在所有时间步长上的累积,我们可能会丢失信号以及长期记忆。但是,这些模型(在每个时间步使用不同的 dropout 掩码)如何能够很好地学习和执行。

我自己已经训练了一个模型,它在每个时间步都使用不同的 dropout 掩码,虽然我没有得到我想要的结果,但该模型能够过度拟合训练数据。据我了解,这会使所有时间步长上的“噪声累积”和“信号丢失”无效(我有 1000 个时间步长序列输入到 GRU 层)。

对这种情况的任何见解、解释或经验都会有所帮助。谢谢。

更新:

为了更清楚起见,我将提到 Dropout Layer 的 keras 文档的摘录(“noise_shape:1D 整数张量,表示将与输入相乘的二进制 dropout 掩码的形状。例如,如果您的输入具有形状(batch_size , timesteps, features) 并且您希望所有时间步的 dropout 掩码都相同,您可以使用 noise_shape=(batch_size, 1, features")。所以,我相信,可以看出,当明确使用 Dropout 层并且需要在每个时间步都使用相同的掩码(如论文中所述),我们需要编辑这个 noise_shape 参数,这在我之前链接的示例中没有完成。