我正在尝试开发用于音乐生成的自动编码器;为了达到这个目的,我正在尝试开发一种捕捉音乐关系的损失函数。
我目前的想法是“螺旋”损失函数,也就是说,如果系统在不同的八度音阶中预测相同的音符,则损失应该比音符错误的情况要小。此外,接近正确音符的音符,例如 B 和 D 到 C,也应该有小的损失。人们可以从概念上将其视为找到线圈或螺旋上两点之间的距离,使得不同八度音阶中的相同音符位于与线圈相切的直线上,但相隔一定的循环距离。
我在 PyTorch 中工作,我的输入表示是 36 x 36 张量,其中行表示音符(MIDI 范围 48:84,钢琴的中间三个八度音阶),列表示时间步长(1 列 = 1/ 100 秒)。矩阵中的值为 0 或 1,表示音符在特定时间打开。
这是我目前对损失的实现:
def SpiralLoss():
def spiral_loss(input, output):
loss = Variable(torch.FloatTensor([0]))
d = 5
r = 10
for i in xrange(input.size()[0]):
for j in xrange(input.size()[3]):
# take along the 1 axis because it's a column vector
inval, inind = torch.max(input[i, :, :, j], 1)
outval, outind = torch.max(output[i, :, :, j], 1)
note_loss = (r*30*(inind%12 - outind%12)).float()
octave_loss = (d*(inind/12 - outind/12)).float()
loss += torch.sqrt(torch.pow(note_loss, 2) + torch.pow(octave_loss, 2))
return loss
return spiral_loss
这种损失的问题是最大函数不可微。我想不出一种方法来区分这种损失,并且想知道是否有人可能有任何想法或建议?
我不确定这是否适合这样的帖子,所以如果不是,我真的很感激任何指向更好位置的指针。