这是我出于兴趣编写的 Levenshtein 距离(编辑距离)的并行实现,但结果令我失望。我在 Core i7 处理器上运行它,因此有多个硬件线程可用。然而,随着线程数的增加,性能反而显著下降:对于相同大小的输入,线程越多,运行得越慢。
我希望有人能看看我使用线程和 java.util.concurrent 包的方式,并告诉我是否哪里做错了。我真正关心的只是并行化为什么没有达到预期效果。我并不指望读者去研究这里复杂的索引计算;我相信我的计算是正确的。但即使不正确,我认为随着线程池中线程数量的增加,仍然应该看到接近线性的加速。
我也附上了我使用的基准测试代码(即下面的第二个代码块)。基准测试使用的是一个第三方基准测试库(代码中的 bb.util.Benchmark)。
谢谢你的帮助 :)。
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
/**
 * Computes the Levenshtein (edit) distance between two strings by sweeping the
 * dynamic-programming grid one anti-diagonal at a time and splitting each
 * anti-diagonal into chunks that are evaluated on a caller-supplied
 * {@link ExecutorService}.
 *
 * <p>Grid coordinates: {@code x} indexes the longer string ({@code maxStr}) and
 * {@code y} indexes the shorter string ({@code minStr}). All cells with
 * {@code x + y == iteration} form one anti-diagonal and depend only on the two
 * preceding anti-diagonals, so they can be computed independently of each
 * other.
 *
 * <p>Storage layout: the array for an anti-diagonal with {@code n} cells has
 * {@code 2 * n - 1} slots. Even slot {@code 2 * i} holds the freshly computed
 * cell {@code i}; odd slot {@code 2 * i + 1} holds a value copied from the
 * previous anti-diagonal, so that the next sweep finds both of its predecessor
 * diagonals interleaved in a single array.
 */
public class EditDistance {
    /**
     * Lower bound on the number of cells given to one task.
     *
     * <p>NOTE(review): each cell costs only a handful of integer operations, so
     * tasks this small are likely dominated by executor hand-off cost; consider
     * raising this value after measuring.
     */
    private static final int MIN_CHUNK_SIZE = 5;

    private final ExecutorService threadPool;
    private final int threadCount;
    private final String maxStr;
    private final String minStr;
    private final int maxLen;
    private final int minLen;

    /**
     * Creates a distance computation for the given pair of strings.
     *
     * @param s1          first string; must not be {@code null}
     * @param s2          second string; must not be {@code null}
     * @param threadPool  executor on which chunks of each anti-diagonal are run;
     *                    it is not shut down by this class
     * @param threadCount number of chunks to aim for per anti-diagonal; values
     *                    below 1 are treated as 1
     */
    public EditDistance(String s1, String s2, ExecutorService threadPool,
            int threadCount) {
        this.threadCount = threadCount;
        this.threadPool = threadPool;
        if (s1.length() < s2.length()) {
            minStr = s1;
            maxStr = s2;
        } else {
            minStr = s2;
            maxStr = s1;
        }
        maxLen = maxStr.length();
        minLen = minStr.length();
    }

    /**
     * Computes the edit distance between the two strings given at construction.
     *
     * @return the minimum number of single-character insertions, deletions and
     *         substitutions that turn one string into the other
     * @throws RuntimeException if a chunk fails or the calling thread is
     *                          interrupted while waiting for chunks
     */
    public int editDist() {
        // With an empty string there is no grid to sweep: the distance is just
        // the number of characters to insert. (Without this guard the sweep
        // would allocate a negative-length array or dereference null.)
        if (minLen == 0) {
            return maxLen;
        }

        int iterations = maxLen + minLen - 1;
        int[] prev = new int[0];
        int[] current = prev;
        for (int i = 0; i < iterations; i++) {
            // Anti-diagonals grow by one cell until they span the shorter
            // string, stay at that size, then shrink by one cell per step.
            int currentLen;
            if (i < minLen) {
                currentLen = i + 1;
            } else if (i < maxLen) {
                currentLen = minLen;
            } else {
                currentLen = iterations - i;
            }
            current = new int[currentLen * 2 - 1];
            parallelize(prev, current, currentLen, i);
            prev = current;
        }
        // The last anti-diagonal has exactly one cell: the bottom-right corner.
        return current[0];
    }

    /**
     * Fills one anti-diagonal, splitting it into chunks of cells that are run on
     * the thread pool, and returns only once every chunk has finished.
     *
     * @param prev       interleaved array produced for the previous anti-diagonal
     * @param current    array to fill for this anti-diagonal
     * @param currentLen number of cells on this anti-diagonal
     * @param iteration  index of this anti-diagonal ({@code x + y} of its cells)
     */
    private void parallelize(int[] prev, int[] current, int currentLen,
            int iteration) {
        // Size chunks by the number of cells (not by the interleaved array
        // length, which is about twice as large) so that roughly threadCount
        // tasks are produced for a long diagonal.
        int workers = Math.max(1, threadCount);
        int chunkSize = Math.max(ceilDiv(currentLen, workers), MIN_CHUNK_SIZE);

        // A diagonal that fits in one chunk gains nothing from a hand-off to
        // another thread; compute it on the calling thread.
        if (currentLen <= chunkSize) {
            doChunk(prev, current, currentLen, iteration, 0, currentLen);
            return;
        }

        List<Future<?>> futures =
                new ArrayList<Future<?>>(ceilDiv(currentLen, chunkSize));
        try {
            for (int start = 0; start < currentLen; start += chunkSize) {
                int stop = Math.min(currentLen, start + chunkSize);
                Runnable worker = new Worker(prev, current, currentLen,
                        iteration, start, stop);
                futures.add(threadPool.submit(worker));
            }
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            // The diagonal is incomplete, so continuing would produce a wrong
            // answer. Stop outstanding work, keep the interrupt flag set for
            // the caller, and fail loudly.
            cancelAll(futures);
            Thread.currentThread().interrupt();
            throw new IllegalStateException(
                    "Interrupted while computing anti-diagonal " + iteration, e);
        } catch (ExecutionException e) {
            // We can only finish the computation if we complete
            // all subproblems
            cancelAll(futures);
            throw new RuntimeException(e);
        }
    }

    /** Requests cancellation of every task in {@code futures}. */
    private static void cancelAll(List<Future<?>> futures) {
        for (Future<?> future : futures) {
            future.cancel(true);
        }
    }

    /** Returns {@code a / b} rounded up, for positive {@code a} and {@code b}. */
    private static int ceilDiv(int a, int b) {
        return a / b + ((a % b == 0) ? 0 : 1);
    }

    /**
     * Computes cells {@code [startIdx, stopIdx)} of one anti-diagonal and copies
     * the matching values of the previous anti-diagonal into the odd slots of
     * {@code current}.
     *
     * @param prev       interleaved array produced for the previous anti-diagonal
     * @param current    array being filled for this anti-diagonal
     * @param currentLen number of cells on this anti-diagonal
     * @param iteration  index of this anti-diagonal
     * @param startIdx   first cell (inclusive) handled by this call
     * @param stopIdx    last cell (exclusive) handled by this call
     */
    private void doChunk(int[] prev, int[] current, int currentLen,
            int iteration, int startIdx, int stopIdx) {
        // While the diagonal still starts on the x == 0 edge, cell i lines up
        // with previous cell i; afterwards the diagonal's start slides along
        // the last row of minStr and everything shifts by one cell (two slots).
        boolean startsOnEdge = iteration < minLen;
        int mergeStartIdx = startsOnEdge ? 0 : 2;

        for (int i = startIdx; i < stopIdx; i++) {
            int x;
            int y;
            int leftIdx;
            int downIdx;
            int diagonalIdx;
            if (startsOnEdge) {
                x = i;
                y = currentLen - i - 1;
                leftIdx = i * 2 - 2;
                downIdx = i * 2;
                diagonalIdx = i * 2 - 1;
            } else {
                x = i + iteration - minLen + 1;
                y = minLen - i - 1;
                leftIdx = i * 2;
                downIdx = i * 2 + 2;
                diagonalIdx = i * 2 + 1;
            }

            // An index outside prev means the neighbour lies on the implicit
            // border of the grid (distance to the empty prefix); on this
            // anti-diagonal that border value is iteration + 1 for the
            // orthogonal neighbours and iteration for the diagonal one.
            int left = 1 + ((leftIdx < 0) ? iteration + 1 : prev[leftIdx]);
            int down = 1 + ((downIdx < prev.length) ? prev[downIdx]
                    : iteration + 1);
            int diagonal = penalty(x, y)
                    + ((diagonalIdx < 0 || diagonalIdx >= prev.length)
                            ? iteration
                            : prev[diagonalIdx]);

            current[i * 2] = Math.min(left, Math.min(down, diagonal));

            // Carry the previous diagonal forward into the odd slot so the
            // next sweep can read it as its "diagonal" predecessor.
            int mergeIdx = i * 2 + 1;
            if (mergeIdx < current.length) {
                current[mergeIdx] = prev[mergeStartIdx + i * 2];
            }
        }
    }

    /**
     * Returns the substitution cost for a grid cell.
     *
     * @param maxIdx index into the longer string
     * @param minIdx index into the shorter string
     * @return 0 if the two characters are equal, otherwise 1
     */
    private int penalty(int maxIdx, int minIdx) {
        return (maxStr.charAt(maxIdx) == minStr.charAt(minIdx)) ? 0 : 1;
    }

    /** Task that evaluates one chunk of an anti-diagonal via {@code doChunk}. */
    private class Worker implements Runnable {
        private final int[] prev;
        private final int[] current;
        private final int currentLen;
        private final int iteration;
        private final int startIdx;
        private final int stopIdx;

        Worker(int[] prev, int[] current, int currentLen, int iteration,
                int startIdx, int stopIdx) {
            this.prev = prev;
            this.current = current;
            this.currentLen = currentLen;
            this.iteration = iteration;
            this.startIdx = startIdx;
            this.stopIdx = stopIdx;
        }

        @Override
        public void run() {
            doChunk(prev, current, currentLen, iteration, startIdx, stopIdx);
        }
    }

    /** Small demo: prints the edit distance between "Saturday" and "Sunday". */
    public static void main(String[] args) {
        int threadCount = 4;
        ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
        try {
            EditDistance ed = new EditDistance("Saturday", "Sunday", threadPool,
                    threadCount);
            System.out.println(ed.editDist());
        } finally {
            // Pool threads are non-daemon; always release them so the JVM can
            // exit even if the computation throws.
            threadPool.shutdown();
        }
    }
}
EditDistance 内部有一个私有内部类 Worker。每个 Worker 负责调用 EditDistance.doChunk 填充 current 数组中的一段范围。EditDistance.parallelize 负责创建这些 Worker,并等待它们全部完成。
我用于基准测试的代码:
import java.io.PrintStream;
import java.util.concurrent.*;
import org.apache.commons.lang3.RandomStringUtils;
import bb.util.Benchmark;
/**
 * Command-line driver that benchmarks {@code EditDistance} on two random
 * alphabetic strings of a given length using a fixed-size thread pool.
 *
 * <p>Usage: {@code <string length> <thread count>}
 */
public class EditDistanceBenchmark {
    /**
     * Runs the benchmark and prints the result and the two input strings.
     *
     * @param args {@code args[0]} is the string length, {@code args[1]} the
     *             thread count; both must be decimal integers
     */
    public static void main(String[] args) {
        if (args.length != 2) {
            System.out.println("Usage: <string length> <thread count>");
            System.exit(1);
            return;
        }

        int strLen;
        int threadCount;
        try {
            strLen = Integer.parseInt(args[0]);
            threadCount = Integer.parseInt(args[1]);
        } catch (NumberFormatException e) {
            // Non-numeric arguments are a usage error, not a crash.
            System.out.println("Usage: <string length> <thread count>");
            System.exit(1);
            return;
        }

        String s1 = RandomStringUtils.randomAlphabetic(strLen);
        String s2 = RandomStringUtils.randomAlphabetic(strLen);

        PrintStream oldOut = System.out;
        ExecutorService threadPool = Executors.newFixedThreadPool(threadCount);
        try {
            Benchmark b;
            // While the benchmark is constructed, anything written to
            // System.out is redirected to System.err.
            // NOTE(review): presumably this keeps the benchmark library's own
            // output off stdout; confirm against the library.
            System.setOut(System.err);
            try {
                b = new Benchmark(new Benchmarker(s1, s2, threadPool, threadCount));
            } finally {
                // Restore stdout even if the benchmark throws.
                System.setOut(oldOut);
            }
            System.out.println("threadCount: " + threadCount +
                    " string length: " + strLen + "\n\n" + b);
            System.out.println("s1: " + s1 + "\ns2: " + s2);
        } finally {
            // Pool threads are non-daemon; without this the JVM would not exit
            // after a failure.
            threadPool.shutdown();
        }
    }

    /** Benchmark task: one full edit-distance computation per {@code run()}. */
    private static class Benchmarker implements Runnable {
        private final String s1, s2;
        private final int threadCount;
        private final ExecutorService threadPool;

        private Benchmarker(String s1, String s2, ExecutorService threadPool, int threadCount) {
            this.s1 = s1;
            this.s2 = s2;
            this.threadPool = threadPool;
            this.threadCount = threadCount;
        }

        @Override
        public void run() {
            EditDistance d = new EditDistance(s1, s2, threadPool, threadCount);
            d.editDist();
        }
    }
}