2

我开始了解递归任务和递归操作的实现。根据我的理解和一些 java 文档,我想出了下面的代码来将数组中的所有数字相加。

我需要帮助来纠正这个问题,并请帮助我指出我哪里出错了。

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

public static void main(String[] args) {

    ForkJoinPool pool = new ForkJoinPool(4);
    long[] numbers = {1,2,3,4,5,6,7,8,9};
    AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length -1 );
    ForkJoinTask<Long> submit = pool.submit(newTask);
    System.out.println(submit.join());
    
}
}

class AdditionTask extends RecursiveTask<Long> {

long[] numbers;
int start;
int end;

public AdditionTask(long[] numbers, int start, int end) {
    this.numbers = numbers;
    this.start = start;
    this.end = end;
}

@Override
protected Long compute() {

    if ((end - start) > 2) {

        int length = numbers.length;
        int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;
        AdditionTask leftSide = new AdditionTask(numbers, 0, mid);

        leftSide.fork();

        AdditionTask rightSide = new AdditionTask(numbers, mid+1, length-1);
        return rightSide.compute() + leftSide.join();

    } else {
        return numbers[0] + numbers[1];
    }


}
}

新代码 [已修复] 这是我修复的代码,似乎只适用于小数组。在下面的示例中,数组大小为 10000,并且总和是错误的。为什么它计算出错误的总和?

import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPoolTest {

    public static void main(String[] args) {

        Random r = new Random();
        int low = 10000;
        int high = 100000;

        int size = 100000;

        long[] numbers = new long[size];
        int sum = 0;
        for (int i = 0; i < size; i++) {
            int n = r.nextInt(high - low) + low;
            numbers[i] = n;
            sum += numbers[i];
        }

        long s = System.currentTimeMillis();
        ForkJoinPool pool = new ForkJoinPool(1);
        AdditionTask newTask = new AdditionTask(numbers, 0, numbers.length-1);
        ForkJoinTask<Long> submit = pool.submit(newTask);
        System.out.println("Expected Answer: " + sum + ", Actual: " + submit.join());
        long e = System.currentTimeMillis();
        System.out.println("Total time taken: " + (e - s) + " ms in parallel Operation");




        long s2 = System.currentTimeMillis();
        System.out.println("Started: " + s2);

        int manualSum = 0;
        for (long number : numbers) {
            manualSum += number;
        }

        System.out.println("Expected Answer: " + sum + ", Actual: " + manualSum);
        long e2 = System.currentTimeMillis();
        System.out.println("Ended: " + e2);
        System.out.println("Total time taken: " + (e2 - s2) + " ms in sequential Operation");
    }
}

class AdditionTask extends RecursiveTask<Long> {

    long[] numbers;
    int start;
    int end;

    public AdditionTask(long[] numbers, int start, int end) {
        this.numbers = numbers;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Long compute() {


        int length = (start == 0) ? end +1 : (end - (start - 1));

        if (length > 2) {

            int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;
            
            AdditionTask leftSide = new AdditionTask(numbers, start, (start+mid));
            leftSide.fork();
            
            AdditionTask rightSide = new AdditionTask(numbers, (start+mid)+1, end);

            Long rightSideLong = rightSide.compute();

            Long leftSideLong = leftSide.join();
            Long total = rightSideLong + leftSideLong;
            
            return total;

        } else {

            if (start == end) {
                return numbers[start];
            }
            return numbers[start] + numbers[end];

        }

    }
}
4

1 回答 1

2

您的并行计算的第二个版本是正确的。但是您的代码中的两个非并行计算在使用int它们的总和时都会被破坏,这对于大型数组会溢出。当您修复它们时,也使用long,它们将产生与您的并行计算相同的结果。

不过,还有一些需要改进的地方。首先,您应该摆脱这些条件:

int length = (start == 0) ? end +1 : (end - (start - 1));

int mid = (length % 2 == 0) ? length / 2 : (length - 1) / 2;

它们没有提供比更简单的好处

int length = end - (start - 1); // or end - start + 1

int mid = length / 2;

那么,一个高效的并行处理不应该尽可能多地分解,而是结合实际可实现的并行性。你可以使用getSurplusQueuedTaskCount()

@Override
protected Long compute() {
    int length = end - (start - 1);
    // only split when benefit from parallel processing is likely
    if (length > 2 && getSurplusQueuedTaskCount() < 2) {
        int mid = length / 2;

        AdditionTask leftSide = new AdditionTask(numbers, start, (start+mid));
        leftSide.fork();

        AdditionTask rightSide = new AdditionTask(numbers, (start+mid)+1, end);

        Long rightSideLong = rightSide.compute();

        // do in this thread if no worker thread has picked it up yet
        Long leftSideLong = leftSide.tryUnfork()? leftSide.compute(): leftSide.join();
        Long total = rightSideLong + leftSideLong;

        return total;
    } else { // do sequential
        long sum = 0;
        for(int ix = start; ix <= end; ix++) sum += numbers[ix];
        return sum;
    }
}
于 2020-10-19T11:54:49.410 回答