我正在研究基于三元组损失的文本嵌入模型。
简短描述:
我有一个关于网上商店的数据库,当用户在搜索栏输入文字时,我需要找到合适的产品。我想要一个模型比匹配字符串更好,并且可以理解用户的想法。我定义了一个这样的三元组网络:我的输入是(查询文本 [anchor],搜索后的下一个产品用户视图 [positive],随机产品 [negative])。我建立了一个基于bi-LSTM的编码器模型,并尝试训练anchor和positive之间的距离最小,anchor和negative之间的距离最大,并使用triplet loss。
我试图实现这个网络在这里输入图像描述
参考:https
:
//arxiv.org/pdf/2104.08558.pdf 我的encoderNet
class encodeNet(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers,
bidirectional, dropout):
#Constructor
super().__init__()
#embedding layer
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.directions = bidirectional
#lstm layer
self.lstm = nn.LSTM(embedding_dim,
hidden_dim,
num_layers=n_layers,
bidirectional=bidirectional,
dropout=dropout,
batch_first=True)
self.fc1 = nn.Linear(hidden_dim * 2, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 512)
self.dropout = nn.Dropout(p=0.3)
self.batchnorm1 = nn.BatchNorm1d(1024)
self.batchnorm2 = nn.BatchNorm1d(512)
self.relu = nn.ReLU()
self.P1 = nn.MaxPool1d(2, stride=2)
self.act = nn.Sigmoid()
def LM(self, text):
embedded = self.embedding(text)
packed_output, (hidden, cell) = self.lstm(embedded)
#concat the final forward and backward hidden state
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
hidden = self.dropout(hidden)
hidden = self.fc1(hidden)
hidden = self.batchnorm1(hidden)
hidden = self.relu(hidden)
hidden = self.fc2(hidden)
hidden = self.batchnorm2(hidden)
hidden = self.fc3(hidden)
return hidden
def forward(self, anchor, pos, neg):
anchor = self.LM(anchor)
pos = self.LM(pos)
neg = self.LM(neg)
anchor = self.P1(anchor)
pos = self.P1(pos)
neg = self.P1(neg)
return anchor, pos,neg
我使用了pytorch框架的损失函数triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)。
结果,我看到在训练数据集中,损失值下降到如此之小和如此之快,但在有效数据集中,损失值没有任何意义,它像随机一样上下波动。
我用 8572 个词汇训练模型,81822 个训练样本,是不是数据集太小了?
你能帮我吗?我的解决方案有什么问题?