你好,张量流同胞!
我有以下架构:
我输入了一些连续变量(实际上是我从 google word2vec 获取的词嵌入),并且我试图预测可以被视为连续和离散的输出(对不起,数学家!但这实际上取决于一个人的训练目标)。输出取值从 0 到 1000,间隔为 0.25(或精度超参数),因此:0, 0.25, 0.50, ..., 100.0。
我知道不可能包含像 tf.to_int (如果有必要我可以省略分数)或 tf.round 之类的东西,因为它们是不可微的,所以我们不能反向传播。但是,我觉得有一些解决方案可以让网络“知道”它正在寻找四舍五入的解决方案:一些整数的小部分,比如 0.25、5.75,但我实际上什至不知道去哪里找。我查了量化,但这似乎有点矫枉过正。
所以我的问题是:
- 如何通知图表我们不接受低于 0.0 的值?在网络输出“logits”(回归预测)上做 abs 是否值得考虑?如果不是,我是否可以修改损失项以严厉惩罚低于 0 的分数并使用绝对误差而不是平方误差?我可能不知道这样做的全部后果
- 我不在乎 4.5 的预测是 4.49999 还是 4.4,因为我将预测四舍五入到最接近的 0.25 以获得准确性,这是我最终的模型评估指标。如果可以,我可以使用吗?
precision = 0.01 # so that sqrt(precision) == 0.1
loss=tf.reduce_mean(tf.max(0, tf.square(tf.sub(logits, targets)) - precision ))