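I have an algorithm implemented in Python that uses numpy for linear algebra. I want to implement it in Java for an Android app, and I have tried many libraries, such as Jama.
The algorithm checks whether one set of 3D points has the same pattern as another set.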
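For reference, the Python implementation can be read as follows. Reference shape $A$ and candidate shape $B$ are each stored as a $3 \times N$ array, with one column per point. This is why the code averages over `axis=1` and why `reference_shape.dot(shape.T)` is $3 \times 3$. The formulas are a sketch of what the code does, not a formal statement of the method. Note that numpy's `svd` already returns $V^{\top}$ as its third value, so the code's `U.dot(V)` is $U V^{\top}$. The result resembles an orthogonal Procrustes fit, except that the rotation is estimated from the uncentered points and there is no reflection check or scaling:

$$
H = A B^{\top} = U \Sigma V^{\top}, \qquad R = U V^{\top}
$$

$$
\bar{a} = \tfrac{1}{N} A \mathbf{1}, \qquad \bar{b} = \tfrac{1}{N} B \mathbf{1}, \qquad t = \bar{a} - R \bar{b}
$$

$$
\text{score} = \left\lVert A - \left(R B + t \mathbf{1}^{\top}\right) \right\rVert_F
$$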
Here is the Python implementation, which works well:
```python
import numpy as np
from numpy.linalg import svd


def compute_similarity(original_shape, shape, nearest_neighbors=False, debug=False):
    # Both shapes are 3 x N arrays: one row per coordinate, one column per point
    reference_shape = np.copy(original_shape)
    # calculate rotation matrix 3 x 3
    # numpy's svd returns (U, S, Vh): the third value is already V transposed
    U, _, V = svd(reference_shape.dot(shape.T))
    rotation_matrix = U.dot(V)
    rotated_shape = np.dot(rotation_matrix, shape)
    # Perform procrustes analysis
    ref_shape_mean_vals = np.mean(reference_shape, axis=1)[:, np.newaxis]
    shape_mean_vals = np.mean(rotated_shape, axis=1)[:, np.newaxis]
    # Translation that aligns the centroids (no scale factor)
    translation = ref_shape_mean_vals - shape_mean_vals
    # Fit shape
    fitted_shape = np.dot(rotation_matrix, shape) + translation
    # Euclidean metric (np.linalg.norm of a 2-D array is the Frobenius norm)
    score = np.linalg.norm(reference_shape - fitted_shape)
    return score
```
If the two point sets have the same pattern, the score should be small (less than 1).
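When porting, it helps to mirror two numpy idioms from the middle of the function. The first is `np.mean(x, axis=1)[:, np.newaxis]`. The second is adding that column to a 3 × N array through broadcasting. Both can be written with plain Java arrays. This is a minimal sketch that assumes the Python layout (3 rows, one column per point). The class and method names are hypothetical and not part of Jama:

```java
import java.util.Arrays;

// Hypothetical helpers mirroring numpy idioms, assuming the Python layout:
// 3 rows (x, y, z) and one column per point.
final class NumpyLikeCentroid {

    // Equivalent of np.mean(m, axis=1): the mean of each row (x, y, z).
    static double[] rowMeans(double[][] m) {
        double[] mean = new double[m.length];
        for (int r = 0; r < m.length; r++) {
            double sum = 0;
            for (double v : m[r]) {
                sum += v;
            }
            mean[r] = sum / m[r].length;
        }
        return mean;
    }

    // Equivalent of m + v[:, np.newaxis]: adds v[r] to every entry of row r.
    static double[][] addToEachColumn(double[][] m, double[] v) {
        double[][] out = new double[m.length][];
        for (int r = 0; r < m.length; r++) {
            out[r] = new double[m[r].length];
            for (int c = 0; c < m[r].length; c++) {
                out[r][c] = m[r][c] + v[r];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Two points, (1, 2, 3) and (3, 4, 5), stored as columns.
        double[][] shape = {{1, 3}, {2, 4}, {3, 5}};
        System.out.println(Arrays.toString(rowMeans(shape))); // [2.0, 3.0, 4.0]
        // Subtracting the centroid centres the points.
        double[][] moved = addToEachColumn(shape, new double[] {-2, -3, -4});
        System.out.println(Arrays.deepToString(moved)); // [[-1.0, 1.0], [-1.0, 1.0], [-1.0, 1.0]]
    }
}
```

Jama's `Matrix` does not broadcast. That is why a Jama port has to build a full matrix of repeated means, as the `mean()` helper in the Java code does.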
Here is my Java implementation:
```java
package com.moussa.zomzoom.Player;

import Jama.Matrix;
import Jama.SingularValueDecomposition;

public class Solver {
    public static boolean isFit(double[][] _reference_shape, double[][] _shape) {
        boolean result = true;
        double threshold = 50;
        // calculate rotation matrix 3 x 3
        Matrix reference_shape = new Matrix(_reference_shape);
        Matrix shape = new Matrix(_shape);
        Matrix dot = reference_shape.times(shape.transpose());
        SingularValueDecomposition SVD = new SingularValueDecomposition(dot);
        Matrix U = SVD.getU();
        Matrix V = SVD.getV();
        Matrix rotation_matrix = U.times(V);
        Matrix rotated_shape = rotation_matrix.times(shape);
        // Perform procrustes analysis
        double score = Double.MAX_VALUE;
        Matrix reference_shape_mean = mean(reference_shape);
        Matrix shape_mean = mean(rotated_shape);
        // double scale = reference_shape.minus(reference_shape_mean).norm1() / shape.minus(shape_mean).norm1();
        // Matrix translation = reference_shape_mean.minus(shape_mean.times(scale));
        Matrix translation = reference_shape_mean.minus(shape_mean);
        // Fit shape
        // Matrix fitted_shape = rotation_matrix.times(shape).times(scale).plus(translation);
        Matrix fitted_shape = rotated_shape.plus(translation);
        // Euclidean metric
        // double euclidean_dist = fitted_shape.distance2(reference_shape);
        score = reference_shape.minus(fitted_shape).norm1();
        result = score < threshold;
        return result;
    }

    private static Matrix mean(Matrix matrix) {
        double[] mean_values = new double[3];
        for (int i = 0; i < matrix.getRowDimension(); i++) {
            mean_values[0] += matrix.get(i, 0);
            mean_values[1] += matrix.get(i, 1);
            mean_values[2] += matrix.get(i, 2);
        }
        mean_values[0] /= matrix.getRowDimension();
        mean_values[1] /= matrix.getRowDimension();
        mean_values[2] /= matrix.getRowDimension();
        double[][] m = new double[matrix.getRowDimension()][matrix.getColumnDimension()];
        for (int i = 0; i < matrix.getRowDimension(); i++) {
            m[i] = mean_values;
        }
        return new Matrix(m);
    }
}
```
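The final norm also differs from the Python version. `np.linalg.norm` on a 2-D array with no `ord` argument returns the Frobenius norm. Jama's `Matrix.norm1()` returns the maximum absolute column sum, and Jama's Frobenius norm is `normF()`. The dependency-free sketch below computes both norms on the same small residual. It shows that the two scores are on different scales, so a threshold tuned for one does not transfer to the other. The class name is hypothetical:

```java
// Hypothetical comparison of the two norms on plain arrays.
final class NormComparison {

    // Frobenius norm: sqrt of the sum of squared entries
    // (what np.linalg.norm returns for a 2-D array by default).
    static double frobenius(double[][] m) {
        double sum = 0;
        for (double[] row : m) {
            for (double v : row) {
                sum += v * v;
            }
        }
        return Math.sqrt(sum);
    }

    // One-norm: the largest column sum of absolute values
    // (what Jama's Matrix.norm1() computes).
    static double maxColumnSum(double[][] m) {
        double best = 0;
        for (int c = 0; c < m[0].length; c++) {
            double sum = 0;
            for (double[] row : m) {
                sum += Math.abs(row[c]);
            }
            best = Math.max(best, sum);
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] residual = {{3, 0}, {4, 0}};
        System.out.println(frobenius(residual));    // 5.0
        System.out.println(maxColumnSum(residual)); // 7.0
    }
}
```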
The same test in Java gives a very large score (> 2000). I noticed that the SVD values from Jama differ from those in Python.
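Two convention differences may account for part of this.

First, numpy's `svd` returns $V^{\top}$ as its third value, so the Python `V` is already transposed. Jama's `getV()` returns $V$ itself, so the Jama counterpart of Python's `U.dot(V)` would be `U.times(V.transpose())`.

Second, the Python shapes are 3 × N. The Java `mean()` helper, however, appears to treat each row as a point (N × 3). With N × 3 inputs, `reference_shape.times(shape.transpose())` is N × N rather than the 3 × 3 matrix the comment describes. The 3 × 3 counterpart is $\text{ref}^{\top} \cdot \text{shape}$, and rotating row-stored points becomes $\text{shape} \cdot R^{\top}$ instead of $R \cdot \text{shape}$.

The dependency-free sketch below covers only those two pieces, assuming N × 3 inputs. The SVD itself would still come from a library, and the class and method names are hypothetical:

```java
// Hypothetical helpers for N x 3 inputs (one point per row).
final class ProcrustesLayout {

    // Plain matrix transpose.
    static double[][] transpose(double[][] m) {
        double[][] t = new double[m[0].length][m.length];
        for (int r = 0; r < m.length; r++) {
            for (int c = 0; c < m[0].length; c++) {
                t[c][r] = m[r][c];
            }
        }
        return t;
    }

    // Plain matrix product a * b.
    static double[][] multiply(double[][] a, double[][] b) {
        double[][] p = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int k = 0; k < b.length; k++) {
                for (int j = 0; j < b[0].length; j++) {
                    p[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return p;
    }

    // 3 x 3 matrix ref^T * shape; equals the Python reference_shape.dot(shape.T)
    // when the Python arrays are the 3 x N transposes of these inputs.
    static double[][] crossCovariance(double[][] refRows, double[][] shapeRows) {
        return multiply(transpose(refRows), shapeRows);
    }

    // R = U * V^T, with V as returned by Jama's getV() (not transposed);
    // this matches numpy's U.dot(Vh).
    static double[][] rotationFromFactors(double[][] u, double[][] v) {
        return multiply(u, transpose(v));
    }

    public static void main(String[] args) {
        // Four points, one per row; shape is ref rotated 90 degrees about z.
        double[][] ref = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 1, 1}};
        double[][] shape = {{0, 1, 0}, {-1, 0, 0}, {0, 0, 1}, {-1, 1, 1}};
        double[][] h = crossCovariance(ref, shape);
        System.out.println(h.length + "x" + h[0].length); // 3x3, not 4x4
    }
}
```

Apache Commons Math follows the same convention. Its `SingularValueDecomposition.getV()` also returns $V$, and `getVT()` returns the transpose. This would be consistent with seeing the same result there.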
I have tried several other Java linear algebra libraries, such as Apache Commons Math, and got the same result.