我希望能够根据给定批次中发生的实例计算条件损失(两种不同的损失)。我正在从头开始编写自定义 train_step,因为我相信这提供了实现我所想的灵活性。但是,我有点纠结于如何实现这一点。
在每个训练步骤,我都在计算批次中每个实例的真实标签和预测标签之间的分类(分类交叉熵)损失,这是标准的。此外,我还包括了一个正则化损失,该损失不是针对批次中的每个实例计算的,而只是实例的一个子集。这就是为什么我提到一个有条件的损失或两个损失目标。
在训练之前,我已经指定了训练实例 id 的列表(每个训练实例都有一个唯一的 id)。每当这些实例中的任何一个碰巧在当前批次中时,我都会仅使用这些特定实例来计算正则化项。如果这些情况都没有发生,我只计算标准分类损失。正则化项的目标是鼓励特定训练实例(由实例 id 指定)和一组附加实例(现在我们可以假设单个实例)之间的特征相似性,以平方距离衡量。
这是我到目前为止所拥有的。这不是一个有效的实现,但希望能展示我所描述的以及我希望实现的目标。模型接受图像张量并输出特征表示(用于正则化项)和预测向量(用于分类损失)。随意忽略我正在使用的方法并建议替代方法。例如,改为创建自定义损失函数或使用tf.cond
可能会有所帮助。注意:我正在使用 tensorflow 2/ tf.GradientTape()
。
class MNIST_Classifier(tf.keras.Model):
def __init__(self, model, train_sub_ids, reg_example, lmbda, **kwargs):
super(MNIST_Classifier, self).__init__(**kwargs)
self.model = model
self.train_sub_ids = train_sub_ids
self.reg_example = reg_examples
self.lmbda = lmbda
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.classification_loss_tracker = keras.metrics.Mean(
name="classification_loss"
)
self.reg_loss_tracker = keras.metrics.Mean(name="reg_loss")
@property
def metrics(self):
return [
self.total_loss_tracker,
self.classification_loss_tracker,
self.reg_loss_tracker,
]
def train_step(self, data):
with tf.GradientTape() as tape:
ids, x, y = data # batch includes id, image, label
_, y_pred = self.model(x) # get predictions, features don't matter for classification loss
# compute classification loss for all instances
classification_loss = tf.reduce_mean(
tf.keras.losses.categorical_crossentropy(y, y_pred)
)
# compute reg loss for subset of instances (could be none)
# step 1: obtain instances from batch where id is in self.train_sub_id
# TODO: this won't work because it's not using tensor operations...need to replace
x_sub = [img for id, img in zip(ids,x) if any(id==i for i in self.train_sub_id)]
if x_sub:
features_sub, _ = self.model(x_sub)
# step 2: compute features and predictions for reg example
features_reg, _ = self.model(x_reg)
# should still work if features_sub and features_reg are different shapes in batch (left most) dim
reg_loss = tf.reduce_mean(tf.math.squared_difference(features_sub, features_reg))
else:
reg_loss = 0
total_loss = classification_loss + self.lmbda*reg_loss
variables = self.trainable_weights
grads = tape.gradient(total_loss, variables)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
self.total_loss_tracker.update_state(total_loss)
self.classification_loss_tracker.update_state(classification_loss)
self.reg_loss_tracker.update_state(reg_loss)
return {
"total_loss": self.total_loss_tracker.result(),
"classification_loss": self.classification_loss_tracker.result(),
"reg_loss": self.reg_loss_tracker.result(),
}