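I want to train my own Drums_RNN model and use it to generate MIDI in the web browser with magenta-js. MIDI generation and synthesis work with the pre-trained drums_rnn model, but I run into problems when I use my own trained model.
1) First, I trained the model as follows: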
drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64]" --num_training_steps=3000
Then I used the model in magenta-js like this:
import { Player, MusicRNN } from '@magenta/music';

const model = new MusicRNN("local_http_link");
const player = new Player();
console.log(player);
await model.initialize();
await player.initialize();
// Manually build the priming sequence (a quantized NoteSequence)
const sequence = {
  notes: [
    {pitch: 36, quantizedStartStep: 0, quantizedEndStep: 1, isDrum: true},
    {pitch: 36, quantizedStartStep: 4, quantizedEndStep: 5, isDrum: true},
    {pitch: 36, quantizedStartStep: 8, quantizedEndStep: 9, isDrum: true},
    // ...
  ],
  quantizationInfo: {stepsPerQuarter: 4},
  tempos: [{time: 0, qpm: 120}],
  totalQuantizedSteps: 9
};
// Generate a new sequence by feeding in the priming sequence
const samples = await model.continueSequence(sequence, 50);
player.resumeContext();
await player.start(samples);
This produces the following error, as described in https://github.com/magenta/magenta-js/issues/106:
Unhandled Rejection (TypeError): Cannot read properties of undefined (reading 'matMul')
2) As suggested in the issue linked above, I got past this error by adding "--hparams=attn_length=0".
So the new training command is:
drums_rnn_train --config="drum_kit" --run_dir="path1" --sequence_example_file="path2" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --num_training_steps=3000
I used the same magenta-js code to generate new MIDI, and now I get the following error:
Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.
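I have also added the full error trace at the bottom of this question. I suspect that the input dimensions of my priming sequence are incompatible with the dimensions the network was trained with, but I don't know how to fix it.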
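One way to read the shapes in that error is sketched below. It assumes the failing op is a standard basic LSTM cell, which multiplies concat([input, h]) by a kernel of shape [inputSize + hiddenSize, 4 * hiddenSize]; the trace does pass through basicLSTMCell, but the interpretation of the numbers is an assumption. The helper name inferLstmInputSizes is hypothetical and not part of magenta-js or TensorFlow.js.

```javascript
// Hypothetical helper (not a magenta-js or TensorFlow.js API).
// Assumption: the matMul is the one inside a basic LSTM cell, i.e.
//   concat([input, h]) x kernel, with kernel shape
//   [inputSize + hiddenSize, 4 * hiddenSize].
function inferLstmInputSizes(dataShape, kernelShape) {
  const fedWidth = dataShape[1]; // width of concat([input, h]) at runtime
  const [kernelRows, kernelCols] = kernelShape;
  if (kernelCols % 4 !== 0) {
    throw new Error('Kernel columns are not a multiple of 4.');
  }
  const hiddenSize = kernelCols / 4; // four gates share one kernel
  return {
    hiddenSize,
    expectedInputSize: kernelRows - hiddenSize, // what the checkpoint wants
    actualInputSize: fedWidth - hiddenSize,     // what the browser code fed
  };
}

// Shapes copied from the error message: 1,74 and 582,256.
const sizes = inferLstmInputSizes([1, 74], [582, 256]);
console.log(sizes);
// { hiddenSize: 64, expectedInputSize: 518, actualInputSize: 10 }
```

Under these assumptions the hidden size comes out as 64, matching rnn_layer_sizes=[64,64], and the checkpoint expects a 518-wide input per step while the browser supplies a 10-wide one. That would point at how the sequence is encoded on the magenta-js side rather than at the priming notes themselves. The value 518 may correspond to a 512-class one-hot drum encoding plus 6 extra bits, but that decomposition is a guess.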
Additional information:
- I can generate new MIDI sequences from this model through the CLI with the following command:
drums_rnn_generate --config="drum_kit" --run_dir="model_path" --hparams="batch_size=64,rnn_layer_sizes=[64,64],attn_length=0" --output_dir="path_2"
However, unlike model.continueSequence() in magenta-js, this command does not require a priming sequence.
- When I generate MIDI in magenta-js with the pre-trained DrumsRNN model, I get no errors.
Full error trace:
Unhandled Rejection (Error): Error in matMul: inner shapes (74) and (582) of Tensors with shapes 1,74 and 582,256 and transposeA=false and transposeB=false must match.
Module.assert
src/util_base.ts:108
105 | assert(a != null, () => `The input to the tensor constructor must be a non-null value.`);
106 | }
107 | // NOTE: We explicitly type out what T extends instead of any so that
> 108 | // util.flatten on a nested array of number doesn't try to infer T as a
109 | // number[][], causing us to explicitly type util.flatten<number>().
110 | /**
111 | * Flattens an arbitrarily nested array.
batchMatMul [as kernelFunc]
src/kernels/BatchMatMul.ts:64
61 | [a3dStrides[0], 1, a3dStrides[1]] :
62 | [a3dStrides[0], a3dStrides[1], 1];
63 | const [bInnerStep, bOuterStep, bBatch] = transposeB ?
> 64 | [1, b3dStrides[1], b3dStrides[0]] :
| ^
65 | [b3dStrides[1], 1, b3dStrides[0]];
66 | const size = leftDim * rightDim;
67 | const result = buffer([batchDim, leftDim, rightDim], a3d.dtype);
kernelFunc
src/engine.ts:598
595 | }
596 | const dataId = backend.write(backendVals, shape, dtype);
597 | const t = new Tensor(shape, dtype, dataId, this.nextTensorId());
> 598 | this.incRef(t, backend);
| ^
599 | // Count bytes for string tensors.
600 | if (dtype === 'string') {
601 | const info = this.state.tensorInfo.get(dataId);
(anonymous function)
src/engine.ts:668
665 | }
666 | this.state.numTensors--;
667 | if (a.dtype === 'string') {
> 668 | this.state.numStringTensors--;
| ^
669 | }
670 | const info = this.state.tensorInfo.get(a.dataId);
671 | const refCount = info.refCount;
Engine.scopedRun
src/engine.ts:453
450 | if (this.shouldCheckForMemLeaks()) {
451 | this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
452 | }
> 453 | const outTensors = outInfos.map((outInfo) => {
| ^
454 | // todo (yassogba) remove this option (Tensor) when node backend
455 | // methods have been modularized and they all return tensorInfo.
456 | // TensorInfos do not have a rank attribute.
Engine.runKernelFunc
src/engine.ts:665
662 | disposeTensor(a) {
663 | if (!this.state.tensorInfo.has(a.dataId)) {
664 | return;
> 665 | }
| ^
666 | this.state.numTensors--;
667 | if (a.dtype === 'string') {
668 | this.state.numStringTensors--;
Engine.runKernel
src/engine.ts:522
519 | outputs = kernelProfile.outputs;
520 | }
521 | });
> 522 | if (isTapeOn) {
| ^
523 | this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved, attrs);
524 | }
525 | if (this.state.profiling) {
matMul_
src/ops/mat_mul.ts:54
matMul__op
src/ops/operation.ts:51
48 | ENGINE.endScope(result);
49 | return result;
50 | }
> 51 | catch (ex) {
| ^
52 | ENGINE.endScope(null);
53 | throw ex;
54 | }
basicLSTMCell_
src/ops/basic_lstm_cell.ts:61
58 | const f = slice(res, [0, sliceCols * 2], sliceSize);
59 | const o = slice(res, [0, sliceCols * 3], sliceSize);
60 | const newC = add(mul(sigmoid(i), tanh(j)), mul($c, sigmoid(add($forgetBias, f))));
> 61 | const newH = mul(tanh(newC), sigmoid(o));
62 | return [newC, newH];
63 | }
64 | export const basicLSTMCell = op({ basicLSTMCell_ });
Module.basicLSTMCell__op
src/ops/operation.ts:51
48 | ENGINE.endScope(result);
49 | return result;
50 | }
> 51 | catch (ex) {
| ^
52 | ENGINE.endScope(null);
53 | throw ex;
54 | }
Array.<anonymous>
src/music_rnn/model.ts:165
162 | else {
163 | sampledOutput = logits.argMax().as1D();
164 | }
> 165 | if (returnProbs) {
| ^
166 | probs.push(tf.softmax(logits));
167 | }
168 | nextInput =
multiRNNCell_
src/ops/multi_rnn_cell.ts:56
Module.multiRNNCell__op
src/ops/operation.ts:51
48 | ENGINE.endScope(result);
49 | return result;
50 | }
> 51 | catch (ex) {
| ^
52 | ENGINE.endScope(null);
53 | throw ex;
54 | }
MusicRNN.sampleRnn
src/music_rnn/model.ts:375
(anonymous function)
src/music_rnn/model.ts:272
(anonymous function)
src/engine.ts:442
439 | // backend and set properties like this.backendName
440 | // tslint:disable-next-line: no-unused-expression
441 | this.backend;
> 442 | }
| ^
443 | const kernel = getKernel(kernelName, this.backendName);
444 | let out;
445 | if (kernel != null) {
Engine.scopedRun
src/engine.ts:453
450 | if (this.shouldCheckForMemLeaks()) {
451 | this.checkKernelForMemLeak(kernelName, numDataIdsBefore, outInfos);
452 | }
> 453 | const outTensors = outInfos.map((outInfo) => {
| ^
454 | // todo (yassogba) remove this option (Tensor) when node backend
455 | // methods have been modularized and they all return tensorInfo.
456 | // TensorInfos do not have a rank attribute.
Engine.tidy
src/engine.ts:440
437 | // can be deferred until an op/ kernel is run).
438 | // The below getter has side effects that will try to initialize the
439 | // backend and set properties like this.backendName
> 440 | // tslint:disable-next-line: no-unused-expression
| ^
441 | this.backend;
442 | }
443 | const kernel = getKernel(kernelName, this.backendName);
Module.tidy
src/globals.ts:192
189 | const tensors = getTensorsInContainer(container);
190 | tensors.forEach(tensor => tensor.dispose());
191 | }
> 192 | /**
193 | * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed
194 | * automatically.
195 | *
MusicRNN.continueSequenceImpl
src/music_rnn/model.ts:258
MusicRNN.continueSequence
src/music_rnn/model.ts:215
playGen
src/magenta/magenta.js:137
134 |
135 |
136 |
> 137 | const samples = await model.continueSequence(sequence, 50);
| ^
138 |
139 | player.resumeContext();
140 | await player.start(samples);