I am new to WEKA's Java API for machine learning.
WEKA does not include a cosine similarity measure, so I want to add one by modifying WEKA's SimpleKMeans algorithm. SimpleKMeans uses EuclideanDistance, and I would like it to use cosine similarity instead.
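To see why the choice matters, here is a small stand-alone comparison in plain Java (no WEKA classes; the class and method names are made up for this illustration). Two vectors that point in the same direction but have different lengths are far apart under Euclidean distance, yet their cosine distance, defined here as 1 - cosine similarity, is 0.

```java
/** Compares Euclidean distance with cosine distance (1 - cosine similarity). */
final class CosineVsEuclidean {

    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    static double euclidean(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return Math.sqrt(s);
    }

    /** cos(a, b) = a.b / (|a| |b|); assumes neither vector is all zeros. */
    static double cosineSimilarity(double[] a, double[] b) {
        return dot(a, b) / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b)));
    }

    /** Cosine distance: 0 for parallel vectors, 1 for orthogonal ones. */
    static double cosineDistance(double[] a, double[] b) {
        return 1.0 - cosineSimilarity(a, b);
    }

    public static void main(String[] args) {
        double[] a = {1, 2, 0};
        double[] b = {2, 4, 0};   // same direction as a, twice as long
        double[] c = {0, 0, 3};   // orthogonal to a
        System.out.println("euclidean(a, b) = " + euclidean(a, b));
        System.out.println("cosine(a, b)    = " + cosineDistance(a, b));
        System.out.println("cosine(a, c)    = " + cosineDistance(a, c));
    }
}
```

Cosine distance ignores vector length and only compares direction, which is usually what you want for data such as term-frequency vectors.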
I searched Google extensively for how to modify the SimpleKMeans code in the open-source WEKA software and found this thread (essentially Pedro's answer):
http://comments.gmane.org/gmane.comp.ai.weka/22681
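The steps described there are:
1. Extend weka.core.EuclideanDistance and override the distance(Instance first, Instance second, PerformanceStats stats) method.
2. Instantiate the extended class, declared with type EuclideanDistance, passing the instances as the argument to the extended class's constructor.
3. Call the setDistanceFunction method of the SimpleKMeans class, passing in that EuclideanDistance instance.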
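Step 2 passes the dataset to the subclass's constructor. If the base-class constructor calls setInstances (my understanding is that the EuclideanDistance(Instances) constructor does so via NormalizableDistance, but please check your WEKA version), a subclass that declares its own field with an `= null` initializer will lose the value, because Java runs subclass field initializers only after `super(...)` returns. The classes below are hypothetical stand-ins that reproduce just this pattern in plain Java; they are not WEKA code.

```java
/** Hypothetical base class whose constructor calls an overridable setter. */
class BaseDistance {
    BaseDistance(String data) {
        setInstances(data);   // virtual call: runs the subclass override
    }

    void setInstances(String data) {
    }
}

/** Field has "= null": that initializer runs AFTER super(...) returns. */
class ShadowingSubclass extends BaseDistance {
    String data = null;

    ShadowingSubclass(String d) {
        super(d);             // sets data, then the initializer resets it to null
    }

    @Override
    void setInstances(String d) {
        data = d;
    }
}

/** No initializer, so the value stored during super(...) survives. */
class SafeSubclass extends BaseDistance {
    String data;

    SafeSubclass(String d) {
        super(d);
    }

    @Override
    void setInstances(String d) {
        data = d;
    }
}

class InitOrderDemo {
    public static void main(String[] args) {
        System.out.println(new ShadowingSubclass("iris").data); // prints "null"
        System.out.println(new SafeSubclass("iris").data);      // prints "iris"
    }
}
```

In a WEKA subclass this shows up as a NullPointerException inside distance() even though setInstances was clearly called.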
Here is my code for the first step:
```java
package weka.core;

import java.util.Enumeration;
import java.util.Vector;

import weka.core.neighboursearch.PerformanceStats;

/**
 * Cosine distance (1 - cosine similarity) for WEKA.
 *
 * Extends EuclideanDistance because SimpleKMeans only accepts Euclidean or
 * Manhattan distance functions. Only numeric, non-class attributes are used;
 * missing values are treated as 0.
 *
 * @author Sgr
 */
public class CosineSimilarity extends EuclideanDistance {

  /** Dataset header. No "= null" initializer, so a value set during construction survives. */
  private Instances m_Dataset;

  public String version = "1.0";

  public CosineSimilarity() {
    super();
  }

  /** Constructor used in step 2: stores the dataset right away. */
  public CosineSimilarity(Instances data) {
    super();
    setInstances(data);
  }

  @Override
  public double distance(Instance first, Instance second) {
    return distance(first, second, Double.POSITIVE_INFINITY, null);
  }

  @Override
  public double distance(Instance first, Instance second, PerformanceStats stats) {
    return distance(first, second, Double.POSITIVE_INFINITY, stats);
  }

  @Override
  public double distance(Instance first, Instance second, double cutOffValue) {
    return distance(first, second, cutOffValue, null);
  }

  /**
   * Returns 1 - cos(first, second). The cut-off is applied only once the
   * full dot product is known: a partial dot product is not a partial distance.
   */
  @Override
  public double distance(Instance first, Instance second, double cutOffValue,
                         PerformanceStats stats) {
    int firstNumValues = first.numValues();
    int secondNumValues = second.numValues();
    int numAttributes = m_Dataset.numAttributes();
    double dot = 0, normA = 0, normB = 0;

    // Merge-walk the stored values of both instances in attribute-index order
    // (works for dense and sparse instances alike).
    for (int p1 = 0, p2 = 0; p1 < firstNumValues || p2 < secondNumValues;) {
      int firstI = (p1 < firstNumValues) ? first.index(p1) : numAttributes;
      int secondI = (p2 < secondNumValues) ? second.index(p2) : numAttributes;

      if (firstI == secondI) {
        // Both instances store this attribute.
        double a = usableValue(firstI, first.valueSparse(p1));
        double b = usableValue(secondI, second.valueSparse(p2));
        dot += a * b;
        normA += a * a;
        normB += b * b;
        p1++;
        p2++;
      } else if (firstI < secondI) {
        // Only the first instance stores it; the second is 0 there.
        double a = usableValue(firstI, first.valueSparse(p1));
        normA += a * a;
        p1++;
      } else {
        // Only the second instance stores it; the first is 0 there.
        double b = usableValue(secondI, second.valueSparse(p2));
        normB += b * b;
        p2++;
      }
      if (stats != null) {
        stats.incrCoordCount();
      }
    }

    if (normA == 0 || normB == 0) {
      // Cosine is undefined for an all-zero vector; treat it as orthogonal.
      return 1.0;
    }
    double cos = dot / (Math.sqrt(normA) * Math.sqrt(normB));
    // Clamp tiny negative values caused by floating-point rounding.
    double dist = Math.max(0.0, 1.0 - cos);
    return (dist > cutOffValue) ? Double.POSITIVE_INFINITY : dist;
  }

  /** Value used in the cosine: 0 for the class, non-numeric or missing values. */
  private double usableValue(int attIndex, double value) {
    if (attIndex == m_Dataset.classIndex()
        || !m_Dataset.attribute(attIndex).isNumeric()
        || Double.isNaN(value)) { // WEKA stores missing values as NaN
      return 0;
    }
    return value;
  }

  @Override
  public String getAttributeIndices() {
    return "first-last"; // all attributes are always considered
  }

  @Override
  public Instances getInstances() {
    return m_Dataset;
  }

  @Override
  public boolean getInvertSelection() {
    return false;
  }

  @Override
  public void postProcessDistances(double[] distances) {
    // Distances returned above are already final; nothing to post-process.
  }

  @Override
  public void setAttributeIndices(String value) {
    // Attribute selection is not supported; all numeric attributes are used.
  }

  @Override
  public void setInstances(Instances insts) {
    super.setInstances(insts); // keep the parent's state consistent
    m_Dataset = insts;
  }

  @Override
  public void setInvertSelection(boolean value) {
    // Not supported; see setAttributeIndices.
  }

  @Override
  public void update(Instance ins) {
    // No attribute ranges to maintain: the cosine is not range-normalized.
  }

  @Override
  public String[] getOptions() {
    return new String[0]; // never null: callers iterate over the result
  }

  @Override
  public Enumeration listOptions() {
    return new Vector<Option>().elements(); // this class has no options
  }

  @Override
  public void setOptions(String[] options) throws Exception {
    // No options to parse.
  }

  @Override
  public String getRevision() {
    return "Cosine distance function written by Sgr, version " + version;
  }
}
```
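To check the sparse merge-walk in distance() independently of WEKA, here is the same loop over plain arrays. Each vector is a hypothetical pair of sorted attribute indices and values, standing in for Instance.index(p) and valueSparse(p). It also makes the cut-off issue easy to see: the loop accumulates a dot product, which grows as the vectors become more similar, so comparing that running sum against a distance cut-off inside the loop can discard the closest neighbours. This sketch therefore computes the distance only after the loop.

```java
/** Stand-alone version of the merge-walk used in distance(). */
final class SparseCosine {

    /** Returns 1 - cos(a, b), or 1.0 if either vector is all zeros. */
    static double cosineDistance(int[] idxA, double[] valA, int[] idxB, double[] valB) {
        double dot = 0, normA = 0, normB = 0;
        int p1 = 0, p2 = 0;
        while (p1 < idxA.length || p2 < idxB.length) {
            // An exhausted vector behaves as if its next index were +infinity.
            int i = p1 < idxA.length ? idxA[p1] : Integer.MAX_VALUE;
            int j = p2 < idxB.length ? idxB[p2] : Integer.MAX_VALUE;
            if (i == j) {            // both vectors store this attribute
                dot += valA[p1] * valB[p2];
                normA += valA[p1] * valA[p1];
                normB += valB[p2] * valB[p2];
                p1++;
                p2++;
            } else if (i < j) {      // only a stores it; b is 0 there
                normA += valA[p1] * valA[p1];
                p1++;
            } else {                 // only b stores it; a is 0 there
                normB += valB[p2] * valB[p2];
                p2++;
            }
        }
        if (normA == 0 || normB == 0) {
            return 1.0;
        }
        double cos = dot / (Math.sqrt(normA) * Math.sqrt(normB));
        return Math.max(0.0, 1.0 - cos);   // clamp rounding noise below 0
    }

    /** Dense reference implementation, used only to cross-check the sparse one. */
    static double denseCosineDistance(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int k = 0; k < a.length; k++) {
            dot += a[k] * b[k];
            na += a[k] * a[k];
            nb += b[k] * b[k];
        }
        return Math.max(0.0, 1.0 - dot / (Math.sqrt(na) * Math.sqrt(nb)));
    }

    public static void main(String[] args) {
        // a = (3, 0, 4, 0, 0) and b = (0, 0, 8, 0, 6) in sparse form
        int[] idxA = {0, 2};
        double[] valA = {3, 4};
        int[] idxB = {2, 4};
        double[] valB = {8, 6};
        System.out.println(cosineDistance(idxA, valA, idxB, valB));
        System.out.println(denseCosineDistance(new double[] {3, 0, 4, 0, 0},
                                               new double[] {0, 0, 8, 0, 6}));
    }
}
```

Here the dot product is 4 * 8 = 32 and the norms are 5 and 10, so both calls should give 1 - 32/50 = 0.36.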
However, I cannot work out the next two steps because I am not very familiar with WEKA.
I looked at the SimpleKMeans source code in WEKA and saw that it creates an instance of the EuclideanDistance class, but I do not know how to proceed from there.
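As far as I can tell from the WEKA 3.8 source of SimpleKMeans (please verify against your version), the distance function is consulted when each instance is assigned to its nearest centroid, while the centroid update uses the attribute-wise mean whenever the distance function is an EuclideanDistance, which includes any subclass of it. If that holds, SimpleKMeans itself should not need editing beyond steps 2 and 3. The toy k-means below is not WEKA's code; it only separates those two steps in plain Java to show where a cosine distance plugs in. The data, initial centroids, and names are made up for the illustration.

```java
import java.util.Arrays;
import java.util.function.ToDoubleBiFunction;

/** Toy k-means (not WEKA's SimpleKMeans) with a pluggable distance. */
final class ToyKMeans {

    /** Assignment step: the only place the distance function is used. */
    static int[] assign(double[][] data, double[][] centroids,
                        ToDoubleBiFunction<double[], double[]> dist) {
        int[] labels = new int[data.length];
        for (int i = 0; i < data.length; i++) {
            double best = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                double d = dist.applyAsDouble(data[i], centroids[c]);
                if (d < best) {
                    best = d;
                    labels[i] = c;
                }
            }
        }
        return labels;
    }

    /** Update step: attribute-wise mean of each cluster's members. */
    static double[][] updateCentroids(double[][] data, int[] labels, double[][] old) {
        int k = old.length, dims = data[0].length;
        double[][] sums = new double[k][dims];
        int[] counts = new int[k];
        for (int i = 0; i < data.length; i++) {
            counts[labels[i]]++;
            for (int j = 0; j < dims; j++) {
                sums[labels[i]][j] += data[i][j];
            }
        }
        for (int c = 0; c < k; c++) {
            if (counts[c] == 0) {          // empty cluster: keep its old centroid
                sums[c] = old[c].clone();
                continue;
            }
            for (int j = 0; j < dims; j++) {
                sums[c][j] /= counts[c];
            }
        }
        return sums;
    }

    /** 1 - cosine similarity; 1.0 if either vector is all zeros. */
    static double cosineDistance(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int j = 0; j < a.length; j++) {
            dot += a[j] * b[j];
            na += a[j] * a[j];
            nb += b[j] * b[j];
        }
        return (na == 0 || nb == 0) ? 1.0 : 1.0 - dot / Math.sqrt(na * nb);
    }

    /** Alternates assignment and update until the labels stop changing. */
    static int[] cluster(double[][] data, double[][] initial, int maxIter) {
        double[][] centroids = initial;
        int[] labels = assign(data, centroids, ToyKMeans::cosineDistance);
        for (int it = 0; it < maxIter; it++) {
            centroids = updateCentroids(data, labels, centroids);
            int[] next = assign(data, centroids, ToyKMeans::cosineDistance);
            if (Arrays.equals(next, labels)) {
                break;                     // converged
            }
            labels = next;
        }
        return labels;
    }

    public static void main(String[] args) {
        // Two directions (near the x axis, near the y axis) at very different lengths.
        double[][] data = {{1, 0.1}, {10, 1.5}, {0.2, 1}, {2, 12}};
        double[][] init = {data[0], data[2]};
        System.out.println(Arrays.toString(cluster(data, init, 10)));
    }
}
```

Note that the mean of un-normalized vectors is pulled toward the longer ones. Scaling each instance to unit length before clustering is a common remedy when cosine distance is the goal, but that is a preprocessing choice you would make yourself, not something the toy code above (or, as far as I know, SimpleKMeans) does for you.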
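Please help me understand the next two steps. If there are mistakes in this implementation of cosine similarity, please point them out. Also, it would be very helpful if someone could modify the SimpleKMeans code in WEKA for my cosine implementation, or explain where in that code I should make changes.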