1

我正在研究基于三元组损失的文本嵌入模型。
简短描述:
我有一个关于网上商店的数据库,当用户在搜索栏输入文字时,我需要找到合适的产品。我想要一个模型比匹配字符串更好,并且可以理解用户的想法。我定义了一个这样的三元组网络:我的输入是(查询文本 [anchor],搜索后的下一个产品用户视图 [positive],随机产品 [negative])。我建立了一个基于bi-LSTM的编码器模型,并尝试训练anchor和positive之间的距离最小,anchor和negative之间的距离最大,并使用triplet loss。
我试图实现这个网络在这里输入图像描述
参考:https
: //arxiv.org/pdf/2104.08558.pdf 我的encoderNet

class encodeNet(nn.Module):

def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, 
             bidirectional, dropout):
    
    #Constructor
    super().__init__()          
    
    #embedding layer
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.directions = bidirectional
    #lstm layer
    self.lstm = nn.LSTM(embedding_dim, 
                       hidden_dim, 
                       num_layers=n_layers, 
                       bidirectional=bidirectional, 
                       dropout=dropout,
                       batch_first=True)
    
    self.fc1 = nn.Linear(hidden_dim * 2, 1024)
    self.fc2 = nn.Linear(1024, 512)
    self.fc3 = nn.Linear(512, 512)
    self.dropout = nn.Dropout(p=0.3)
    self.batchnorm1 = nn.BatchNorm1d(1024)
    self.batchnorm2 = nn.BatchNorm1d(512)
    self.relu = nn.ReLU()
    self.P1 = nn.MaxPool1d(2, stride=2)
    self.act = nn.Sigmoid()
    
def LM(self, text):
    embedded = self.embedding(text)       
    packed_output, (hidden, cell) = self.lstm(embedded)
    #concat the final forward and backward hidden state
    hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
    hidden = self.dropout(hidden)
    hidden = self.fc1(hidden)
    hidden = self.batchnorm1(hidden)  
    hidden = self.relu(hidden)
    hidden = self.fc2(hidden)       
    hidden = self.batchnorm2(hidden)  
    hidden = self.fc3(hidden)
    return hidden
def forward(self, anchor, pos, neg):
    anchor = self.LM(anchor)
    pos = self.LM(pos)
    neg = self.LM(neg)
    anchor = self.P1(anchor)
    pos = self.P1(pos)
    neg = self.P1(neg)
    return anchor, pos,neg

我使用了pytorch框架的损失函数triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)。
结果,我看到在训练数据集中,损失值下降到如此之小和如此之快,但在有效数据集中,损失值没有任何意义,它像随机一样上下波动。
我用 8572 个词汇训练模型,81822 个训练样本,是不是数据集太小了?
你能帮我吗?我的解决方案有什么问题?

4

1 回答 1

0

我建议你使用 Hard-Triplet。您可以在 FaceNet 论文中了解更多信息。我希望它对你有帮助。

于 2022-01-18T18:12:04.737 回答