我编写了一个程序来计算标准化互信息(NMI),用于评估社区检测的结果。但我得到的 NMI 值大于 1,而它通常应该在 0 和 1 之间。我是按照 http://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html 中的公式实现的。
我的代码有什么问题?
这是我的代码:
/**
 * Computes and prints the normalized mutual information (NMI) between a
 * reference community structure and a detected one:
 * {@code NMI = 2 * I(real; found) / (H(real) + H(found))}, with all logarithms
 * taken in base 2.
 *
 * <p>The formula only yields a value in {@code [0, 1]} when each map describes a
 * disjoint partition of the same {@code size} nodes. If communities overlap,
 * contain duplicate node ids, or {@code size} is not the real node count, the
 * probabilities no longer sum to 1 and the printed value can exceed 1. A warning
 * is written to {@code System.err} when the community sizes do not add up to
 * {@code size}; this is a cheap heuristic, not a full disjointness check.
 *
 * <p>Output on {@code System.out}: a progress counter every 100 reference
 * communities, then the mutual information with both entropies, then the NMI.
 *
 * @param size     total number of nodes; must be positive
 * @param realcom  reference communities, keyed by community id
 * @param foundcom detected communities, keyed by community id
 * @throws IllegalArgumentException if {@code size} is not positive
 */
public void calNmi(int size, Map<Integer, List<Integer>> realcom, Map<Integer, List<Integer>> foundcom){
    if (size <= 0) {
        throw new IllegalArgumentException("size must be positive, got " + size);
    }
    final double ln2 = Math.log(2);

    // Entropy of the detected communities. Empty communities are skipped:
    // 0 * log(0) is taken as 0, whereas the raw expression evaluates to NaN.
    double sumF = 0;
    long foundMembers = 0;
    for (List<Integer> fCom : foundcom.values()) {
        foundMembers += fCom.size();
        double p = (double) fCom.size() / size;
        if (p > 0) {
            sumF += -p * (Math.log(p) / ln2);
        }
    }

    // Entropy of the reference communities, same convention.
    double sumH = 0;
    long realMembers = 0;
    for (List<Integer> rCom : realcom.values()) {
        realMembers += rCom.size();
        double p = (double) rCom.size() / size;
        if (p > 0) {
            sumH += -p * (Math.log(p) / ln2);
        }
    }

    if (realMembers != size || foundMembers != size) {
        System.err.println("calNmi: community sizes sum to " + realMembers + " (real) and "
                + foundMembers + " (found) but size = " + size
                + "; NMI is only bounded by [0, 1] when each input is a disjoint partition of exactly size nodes");
    }

    // Mutual information over every (real, found) community pair.
    double mutualInfo = 0;
    int processed = 0;
    for (Map.Entry<Integer, List<Integer>> real : realcom.entrySet()) {
        if (processed % 100 == 0) {
            System.out.println(processed);
        }
        List<Integer> rCom = real.getValue();
        double pr = (double) rCom.size() / size;
        for (Map.Entry<Integer, List<Integer>> found : foundcom.entrySet()) {
            List<Integer> fCom = found.getValue();
            double pf = (double) fCom.size() / size;
            int intersect = foundIntersect(rCom, fCom);
            // Pairs with no shared node contribute nothing (and would hit log(0)).
            if (intersect != 0) {
                double p = (double) intersect / size;
                mutualInfo += p * (Math.log(p / (pr * pf))) / ln2;
            }
        }
        processed++;
    }

    System.out.println("I = " + mutualInfo + " hf = "+ sumF + " hr " + sumH);

    // Both entropies are zero when neither side splits the data at all; the two
    // structures then agree trivially, so report 1 instead of 0/0 = NaN.
    double denominator = sumF + sumH;
    double nmi = denominator > 0 ? (2 * mutualInfo) / denominator : 1.0;
    System.out.println("nmi = " + nmi);
}
/**
 * Counts how many entries of {@code rCom} also occur in {@code fCom}.
 *
 * <p>Every entry of {@code rCom} is tested individually, so a value that appears
 * several times in {@code rCom} is counted once per occurrence.
 *
 * @param rCom node ids of the first community
 * @param fCom node ids of the second community
 * @return number of entries of {@code rCom} that are contained in {@code fCom}
 */
private int foundIntersect(List<Integer> rCom, List<Integer> fCom) {
    if (rCom.isEmpty() || fCom.isEmpty()) {
        return 0;
    }
    // Hash fCom once so each lookup is O(1); List.contains would rescan fCom for
    // every entry of rCom. Names are fully qualified so no extra import is needed.
    java.util.Set<Integer> members = new java.util.HashSet<>(fCom);
    int count = 0;
    for (Integer r : rCom) {
        if (members.contains(r)) {
            count++;
        }
    }
    return count;
}