I want to train my own Drums_RNN model and use it to generate MIDI in the web browser with magenta-js. MIDI generation and synthesis with the pre-trained drums_rnn model works fine, but I run into problems when I use my own trained model.

1) First, I train the model as follows:

drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64]" --num_training_steps=3000

Then I use that model in magenta-js like this:

  import { Player, MusicRNN } from '@magenta/music';

  const model = new MusicRNN("local_http_link");
  const player = new Player();
  console.log(player);
  await model.initialize();
  await player.initialize();

  // manually generate priming sequence here
  const sequence = {
    notes: [
      {pitch: 36, quantizedStartStep: 0, quantizedEndStep: 1, isDrum: true},
      {pitch: 36, quantizedStartStep: 4, quantizedEndStep: 5, isDrum: true},
      {pitch: 36, quantizedStartStep: 8, quantizedEndStep: 9, isDrum: true},
      // ...
    ],
    quantizationInfo: {stepsPerQuarter: 4},
    tempos: [{time: 0, qpm: 120}],
    totalQuantizedSteps: 9
  };

  // generate a new sequence by feeding in the priming sequence
  const samples = await model.continueSequence(sequence, 50);
  player.resumeContext();
  await player.start(samples);

This results in the following error, as described in https://github.com/magenta/magenta-js/issues/106:

Unhandled Rejection (TypeError): Cannot read properties of undefined (reading 'matMul')

2) As suggested in the issue linked above, I got past this error by including "attn_length=0" in the --hparams string.

So the new training command is:

drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --num_training_steps=3000

I then generate new MIDI with the same magenta-js code and now get the following error:

Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.

I have also added the full error trace at the bottom of this question. My feeling is that the input dimensions of my priming sequence are not compatible with the dimensions of the trained network, but I do not know how to fix this.
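
To narrow this down, I wanted to compare the settings baked into the converted checkpoint with my training hparams. Below is only a rough debugging sketch: it assumes the directory served at local_http_link contains a config.json describing the data converter and RNN settings, the way the hosted Magenta checkpoints do, and inspectCheckpoint as well as the logged field names are just placeholders for this sketch, not something I have verified.

    // Debugging sketch (assumption: the converted checkpoint directory served at
    // local_http_link contains a config.json, like the hosted Magenta checkpoints).
    const checkpointURL = 'local_http_link'; // same URL passed to new MusicRNN(...)

    // inspectCheckpoint is just a helper name for this sketch.
    async function inspectCheckpoint() {
      const response = await fetch(`${checkpointURL}/config.json`);
      const config = await response.json();
      // Log the whole config and compare it against the training hparams
      // (rnn_layer_sizes=[64,64], attn_length=0). Key names such as
      // config.dataConverter are assumptions based on the hosted checkpoints.
      console.log('checkpoint config:', config);
      console.log('dataConverter:', config.dataConverter);
    }

    inspectCheckpoint();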

Additional information:

  1. I can generate new MIDI sequences from this model via the CLI with the following command:

    drums_rnn_generate --config="drum_kit" --run_dir="model_path" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --output_dir="path_2"

However, unlike model.continueSequence() in magenta-js, this command does not require a priming sequence.

  2. When I generate MIDI in magenta-js with the pre-trained DrumsRNN model, I do not get any errors (the code I use for that comparison is sketched right after this list).
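
For comparison, here is a sketch of the working pre-trained setup. The only difference to my failing code is the checkpoint URL: the hosted drum_kit_rnn checkpoint instead of my locally served one. The priming sequence and the rest of the code stay the same as above, and playPretrained is just a wrapper name for this sketch.

    import { Player, MusicRNN } from '@magenta/music';

    // Same generation code as above, but pointing at the hosted drum_kit_rnn
    // checkpoint; with this URL continueSequence() runs without errors for me.
    const model = new MusicRNN(
      'https://storage.googleapis.com/magentadata/js/checkpoints/music_rnn/drum_kit_rnn');
    const player = new Player();

    // playPretrained is only a wrapper for this sketch; pass the same
    // priming sequence object used earlier.
    async function playPretrained(primingSequence) {
      await model.initialize();
      const samples = await model.continueSequence(primingSequence, 50);
      player.resumeContext();
      await player.start(samples);
    }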

Full error trace:

Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.
Module.assert
src/util_base.ts:108
  105 |     assert(a != null, () => `The input to the tensor constructor must be a non-null value.`);
  106 | }
  107 | // NOTE: We explicitly type out what T extends instead of any so that
> 108 | // util.flatten on a nested array of number doesn't try to infer T as a
  109 | // number[][], causing us to explicitly type util.flatten<number>().
  110 | /**
  111 |  *  Flattens an arbitrarily nested array.
View compiled
batchMatMul [as kernelFunc]
src/kernels/BatchMatMul.ts:64
  61 |     [a3dStrides[0], 1, a3dStrides[1]] :
  62 |     [a3dStrides[0], a3dStrides[1], 1];
  63 | const [bInnerStep, bOuterStep, bBatch] = transposeB ?
> 64 |     [1, b3dStrides[1], b3dStrides[0]] :
     | ^  65 |     [b3dStrides[1], 1, b3dStrides[0]];
  66 | const size = leftDim * rightDim;
  67 | const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
View compiled
kernelFunc
src/engine.ts:598
  595 | }
  596 | const dataId = backend.write(backendVals, shape, dtype);
  597 | const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
> 598 | this.incRef(t, backend);
      | ^  599 | // Count bytes for string tensors.
  600 | if (dtype === 'string') {
  601 |     const info = this.state.tensorInfo.get(dataId);
View compiled
(anonymous function)
src/engine.ts:668
  665 | }
  666 | this.state.numTensors--;
  667 | if (a.dtype === 'string') {
> 668 |     this.state.numStringTensors--;
      | ^  669 | }
  670 | const info = this.state.tensorInfo.get(a.dataId);
  671 | const refCount = info.refCount;
View compiled
Engine.scopedRun
src/engine.ts:453
  450 | if (this.shouldCheckForMemLeaks()) {
  451 |     this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
  452 | }
> 453 | const outTensors = outInfos.map((outInfo) => {
      | ^  454 |     // todo (yassogba) remove this option (Tensor) when node backend
  455 |     // methods have been modularized and they all return tensorInfo.
  456 |     // TensorInfos do not have a rank attribute.
View compiled
Engine.runKernelFunc
src/engine.ts:665
  662 | disposeTensor(a) {
  663 |     if (!this.state.tensorInfo.has(a.dataId)) {
  664 |         return;
> 665 |     }
      | ^  666 |     this.state.numTensors--;
  667 |     if (a.dtype === 'string') {
  668 |         this.state.numStringTensors--;
View compiled
Engine.runKernel
src/engine.ts:522
  519 |         outputs = kernelProfile.outputs;
  520 |     }
  521 | });
> 522 | if (isTapeOn) {
      | ^  523 |     this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
  524 | }
  525 | if (this.state.profiling) {
View compiled
matMul_
src/ops/mat_mul.ts:54
matMul__op
src/ops/operation.ts:51
  48 |     ENGINE.endScope(result);
  49 |     return result;
  50 | }
> 51 | catch (ex) {
     | ^  52 |     ENGINE.endScope(null);
  53 |     throw ex;
  54 | }
View compiled
basicLSTMCell_
src/ops/basic_lstm_cell.ts:61
  58 |     const f = slice(res, [0, sliceCols * 2], sliceSize);
  59 |     const o = slice(res, [0, sliceCols * 3], sliceSize);
  60 |     const newC = add(mul(sigmoid(i), tanh(j)), mul($c, sigmoid(add($forgetBias, f))));
> 61 |     const newH = mul(tanh(newC), sigmoid(o));
  62 |     return [newC, newH];
  63 | }
  64 | export const basicLSTMCell = op({ basicLSTMCell_ });
View compiled
Module.basicLSTMCell__op
src/ops/operation.ts:51
  48 |     ENGINE.endScope(result);
  49 |     return result;
  50 | }
> 51 | catch (ex) {
     | ^  52 |     ENGINE.endScope(null);
  53 |     throw ex;
  54 | }
View compiled
Array.<anonymous>
src/music_rnn/model.ts:165
  162 | else {
  163 |     sampledOutput = logits.argMax().as1D();
  164 | }
> 165 | if (returnProbs) {
      | ^  166 |     probs.push(tf.softmax(logits));
  167 | }
  168 | nextInput =
View compiled
multiRNNCell_
src/ops/multi_rnn_cell.ts:56
Module.multiRNNCell__op
src/ops/operation.ts:51
  48 |     ENGINE.endScope(result);
  49 |     return result;
  50 | }
> 51 | catch (ex) {
     | ^  52 |     ENGINE.endScope(null);
  53 |     throw ex;
  54 | }
View compiled
MusicRNN.sampleRnn
src/music_rnn/model.ts:375
(anonymous function)
src/music_rnn/model.ts:272
(anonymous function)
src/engine.ts:442
  439 |     // backend and set properties like this.backendName
  440 |     // tslint:disable-next-line: no-unused-expression
  441 |     this.backend;
> 442 | }
      | ^  443 | const kernel = getKernel(kernelName, this.backendName);
  444 | let out;
  445 | if (kernel != null) {
View compiled
Engine.scopedRun
src/engine.ts:453
  450 | if (this.shouldCheckForMemLeaks()) {
  451 |     this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
  452 | }
> 453 | const outTensors = outInfos.map((outInfo) => {
      | ^  454 |     // todo (yassogba) remove this option (Tensor) when node backend
  455 |     // methods have been modularized and they all return tensorInfo.
  456 |     // TensorInfos do not have a rank attribute.
View compiled
Engine.tidy
src/engine.ts:440
  437 |     // can be deferred until an op/ kernel is run).
  438 |     // The below getter has side effects that will try to initialize the
  439 |     // backend and set properties like this.backendName
> 440 |     // tslint:disable-next-line: no-unused-expression
      | ^  441 |     this.backend;
  442 | }
  443 | const kernel = getKernel(kernelName, this.backendName);
View compiled
Module.tidy
src/globals.ts:192
  189 |     const tensors = getTensorsInContainer(container);
  190 |     tensors.forEach(tensor => tensor.dispose());
  191 | }
> 192 | /**
  193 |  * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
  194 |  * automatically.
  195 |  *
View compiled
MusicRNN.continueSequenceImpl
src/music_rnn/model.ts:258
MusicRNN.continueSequence
src/music_rnn/model.ts:215
playGen
src/magenta/magenta.js:137
  134 | 
  135 | 
  136 | 
> 137 | const samples = await model.continueSequence(sequence, 50);
      | ^  138 | 
  139 | player.resumeContext();
  140 | await player.start(samples);
