I'm trying to train a Word2Vec model with deeplearning4j.
Everything works fine until a large number of errors like this one appear:
o.d.p.Parallelization - Error occurred processing data
java.lang.IllegalArgumentException: Unable to get linear index >= 10
    at org.nd4j.linalg.api.ndarray.BaseNDArray.getDouble(BaseNDArray.java:3287) ~[nd4j-api-0.0.3.5.5.5.jar:na]
    at org.deeplearning4j.models.word2vec.VocabWord.getGradient(VocabWord.java:137) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable.iterateSample(InMemoryLookupTable.java:277) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.word2vec.Word2Vec.iterate(Word2Vec.java:343) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
    at org.deeplearning4j.models.word2vec.Word2Vec.skipGram(Word2Vec.java:331) ~[deeplearning4j-nlp-0.0.3.3.4.alpha2.jar:na]
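The exception is thrown by BaseNDArray.getDouble, called from VocabWord.getGradient: presumably an array of length 10 is being read at a linear index of 10 or more. Because the training code sets useAdaGrad(true), getGradient is plausibly part of a per-word AdaGrad step, but that is an assumption. The following is a generic, hypothetical Java sketch (the class AdaGradHistory and its methods are made up; this is not DL4J's VocabWord) of a per-coordinate AdaGrad history buffer with a bounds-checked read, showing how an index at or past the buffer length produces this kind of message.

```java
// Hypothetical illustration only, NOT DL4J's VocabWord: a per-coordinate
// AdaGrad history whose reads are bounds-checked like a flat array.
class AdaGradHistory {
    private final double[] history; // accumulated squared gradients, one slot per coordinate
    private final double baseRate;  // base learning rate
    private static final double EPS = 1e-6;

    AdaGradHistory(int length, double baseRate) {
        this.history = new double[length];
        this.baseRate = baseRate;
    }

    // Returns the AdaGrad-scaled step for coordinate i after adding g^2 to its history.
    double scaledGradient(int i, double g) {
        // Assumed wording, modeled on the message in the stack trace above.
        if (i >= history.length) {
            throw new IllegalArgumentException("Unable to get linear index >= " + history.length);
        }
        history[i] += g * g;
        return baseRate * g / (Math.sqrt(history[i]) + EPS);
    }

    public static void main(String[] args) {
        AdaGradHistory h = new AdaGradHistory(10, 0.025);
        System.out.println(h.scaledGradient(3, 0.5)); // in range: a normal step
        try {
            h.scaledGradient(10, 0.5); // one past the end of a length-10 buffer
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The point of the sketch is only that the history buffer must be at least as long as the largest index passed in; a mismatch between the size used to allocate it and the indices used to read it fails exactly at the first out-of-range index.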
Training keeps running, but this error seems to affect the results.
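The log prefix o.d.p.Parallelization suggests the exception is caught and logged inside a worker thread rather than propagated, which would explain why training keeps running. If that is the case, every batch that throws simply loses its update, which could plausibly degrade the vectors. Below is a hypothetical sketch of that catch-and-continue pattern using a plain ExecutorService; the class CatchAndContinue, its methods, and the rule that batches numbered 10 or more fail are all made up for illustration and are not DL4J's Parallelization code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical sketch, NOT DL4J's Parallelization: each batch runs in a worker
// that catches and logs exceptions, so a failing batch is skipped, not fatal.
class CatchAndContinue {
    // Returns how many batches completed their update.
    static int trainAll(List<Integer> batches, int workers) {
        ExecutorService pool = Executors.newFixedThreadPool(workers);
        AtomicInteger applied = new AtomicInteger();
        for (int batch : batches) {
            pool.execute(() -> {
                try {
                    processBatch(batch);       // may throw for some batches
                    applied.incrementAndGet(); // counted only if the update completed
                } catch (IllegalArgumentException e) {
                    // Logged and swallowed: this batch's update is silently lost.
                    System.err.println("Error occurred processing data: " + e.getMessage());
                }
            });
        }
        pool.shutdown();
        try {
            pool.awaitTermination(1, TimeUnit.MINUTES);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return applied.get();
    }

    // Stand-in for one training step; pretends batches >= 10 hit an out-of-range index.
    static void processBatch(int batch) {
        if (batch >= 10) {
            throw new IllegalArgumentException("Unable to get linear index >= 10");
        }
    }

    public static void main(String[] args) {
        List<Integer> batches = new ArrayList<>();
        for (int i = 0; i < 15; i++) {
            batches.add(i);
        }
        System.out.println("applied " + trainAll(batches, 3) + " of " + batches.size());
    }
}
```

In this pattern the job as a whole "succeeds", and the only symptom of the dropped updates is the error log, which matches the behavior described above.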
The error still occurs with version 0.0.3.3.4.alpha2 and with 0.4-rc1.2.
Here is the complete code, implemented following the tutorial:
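import java.io.File;
import java.io.IOException;

import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.text.sentenceiterator.LineSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;

// Method of a class that declares the fields trainSetFileName, outputFileName and wordVectors.
public void train() throws IOException {
    // One sentence per line of the training file, lowercased.
    SentenceIterator iter = new LineSentenceIterator(new File(trainSetFileName));
    iter.setPreProcessor(new SentencePreProcessor() {
        @Override
        public String preProcess(String sentence) {
            return sentence.toLowerCase();
        }
    });

    // Lowercase each token, then strip word endings with EndingPreProcessor.
    final EndingPreProcessor preProcessor = new EndingPreProcessor();
    TokenizerFactory tokenizer = new DefaultTokenizerFactory();
    tokenizer.setTokenPreProcessor(new TokenPreProcess() {
        @Override
        public String preProcess(String token) {
            token = token.toLowerCase();
            String base = preProcessor.preProcess(token);
            //base = base.replaceAll("\\d", "d");
            return base;
        }
    });

    int batchSize = 1000;
    int iterations = 30;
    int layerSize = 300;

    Word2Vec vec = new Word2Vec.Builder()
            .batchSize(batchSize)        // words per minibatch
            .sampling(1e-5)              // subsampling threshold: randomly drops occurrences of very frequent words (not negative sampling)
            .minWordFrequency(5)         // ignore words seen fewer than 5 times
            .useAdaGrad(true)            // adaptive (AdaGrad) updates
            .layerSize(layerSize)        // word vector size
            .iterations(iterations)      // number of training iterations
            .learningRate(0.025)         // initial learning rate
            .minLearningRate(1e-2)       // floor for the learning rate, which decays with the number of words processed
            .negativeSample(10)          // negative samples per training word
            .iterate(iter)               // sentence source
            .tokenizerFactory(tokenizer) // tokenization plus the token preprocessing above
            .saveVocab(true)
            .workers(3)                  // parallel worker threads
            .build();
    vec.fit();

    wordVectors = vec; // keep the trained model in a field
    WordVectorSerializer.writeWordVectors(vec, outputFileName);
}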
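Two of the builder settings follow formulas from the original word2vec.c tool: sampling(1e-5) is the subsampling threshold t for frequent words (it is not negative sampling), and learningRate/minLearningRate describe a learning rate that decays linearly with progress down to a floor. The sketch below reproduces those two formulas in plain Java. Whether DL4J 0.0.3.3.x or 0.4-rc computes them identically is an assumption; the minRate floor stands in for word2vec.c's fixed floor, and the class Word2VecSchedules and its methods are hypothetical names.

```java
// Sketch of two Word2Vec hyperparameters, following the formulas in the original
// word2vec.c; DL4J's exact implementation is assumed, not confirmed.
class Word2VecSchedules {
    // Probability of KEEPING one occurrence of a word seen `count` times out of
    // `totalWords` tokens, given the subsampling threshold `sample` (e.g. 1e-5).
    static double keepProbability(long count, long totalWords, double sample) {
        double threshold = sample * totalWords;
        double p = (Math.sqrt(count / threshold) + 1) * threshold / count;
        return Math.min(1.0, p); // rare words get p >= 1, i.e. they are always kept
    }

    // Linear learning-rate decay over training progress, clamped at minRate.
    static double learningRate(double startRate, double minRate, long wordsSeen, long totalWords) {
        double decayed = startRate * (1.0 - (double) wordsSeen / (totalWords + 1));
        return Math.max(minRate, decayed);
    }

    public static void main(String[] args) {
        long total = 1_000_000L;
        // A very frequent word (5% of all tokens) is kept only rarely.
        System.out.println(keepProbability(50_000, total, 1e-5));
        // A word at minWordFrequency(5) is always kept.
        System.out.println(keepProbability(5, total, 1e-5));
        // Learning rate after 90% of the words: already clamped at the floor.
        System.out.println(learningRate(0.025, 0.01, 900_000, total));
    }
}
```

With learningRate(0.025) and minLearningRate(1e-2), the floor is 40% of the starting rate, so under a linear schedule like this one the rate stops decaying after roughly 60% of training progress (0.025 × (1 − 0.6) = 0.01) and stays constant afterwards.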