
Given the following code:

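import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

// RandomStringUtils comes from Apache Commons Lang 3.
import org.apache.commons.lang3.RandomStringUtils;

public class StreamSpliteratorTest {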

    public static void main(String[] args) {
        var ids = Stream.generate(() -> RandomStringUtils.randomAlphanumeric(5))
            .limit(100)
            .collect(Collectors.toList());
        get(ids)
            .forEach(resultMap -> {
                System.out.printf("Got result map with size %s%n", resultMap.size());
            });
    }

    static Stream<Map<String, String>> get(Collection<String> ids) {
        var remainingIds = new HashSet<>(ids);
        var initialCount = remainingIds.size();
        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new Iterator<>() {

            @Override
            public boolean hasNext() {
                return !remainingIds.isEmpty();
            }

            @Override
            public Map<String, String> next() {
                Map<String, String> result = Map.of();

                try {
                    // Take up to 10 ids, remove them from the pending set, and fetch them.
                    var chunk = remainingIds.stream().limit(10).collect(Collectors.toSet());
                    remainingIds.removeAll(chunk);
                    result = fetch(chunk).get(5, TimeUnit.SECONDS);
                    System.out.printf("%s of %s ids done%n", initialCount - remainingIds.size(), initialCount);
                } catch (Exception e) {
                    System.err.printf("Request thread pool was interrupted: %s%n", e.getMessage());
                }

                return result;
            }
        }, Spliterator.IMMUTABLE), true);
    }

    static CompletableFuture<Map<String, String>> fetch(Collection<String> ids) {
        var delay = CompletableFuture.delayedExecutor(1, TimeUnit.SECONDS);
        return CompletableFuture
            .supplyAsync(() -> ids.stream().collect(Collectors.toMap(e -> e, e -> e)), delay);
    }
}

The output is:

10 of 100 ids done
20 of 100 ids done
30 of 100 ids done
40 of 100 ids done
50 of 100 ids done
60 of 100 ids done
70 of 100 ids done
80 of 100 ids done
90 of 100 ids done
100 of 100 ids done
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10

I'm confused about why the result of each next() call is not passed immediately to the consumer in the calling code. I expected the following output:


10 of 100 ids done
Got result map with size 10
20 of 100 ids done
Got result map with size 10
30 of 100 ids done
Got result map with size 10
40 of 100 ids done
Got result map with size 10
50 of 100 ids done
Got result map with size 10
60 of 100 ids done
Got result map with size 10
70 of 100 ids done
Got result map with size 10
80 of 100 ids done
Got result map with size 10
90 of 100 ids done
Got result map with size 10
100 of 100 ids done
Got result map with size 10

What am I doing wrong here?


2 Answers


Actually, I think what is happening is that StreamSupport.stream() first walks the Iterator to build the Stream, and only afterwards is forEach applied.

StreamSupport.stream(Spliterators.spliteratorUnknownSize(new Iterator<>() {

The forEach action is not invoked once per next() call; it is invoked on the Stream that was already filled from the Iterator. That is why you see the sequence you see.

Answered 2022-02-04T12:16:31.203

The Spliterator returned from Spliterators.spliteratorUnknownSize tries to batch many elements from the iterator in order to improve parallel performance.

You can see this here: https://github.com/openjdk/jdk/blob/51b53a821bb3cfb962f80a637f5fb8cde988975a/src/java.base/share/classes/java/util/Spliterators.java#L1828
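To make the batching visible, here is a minimal sketch that counts how many times next() is called on the source iterator during a single trySplit() call, which is the operation a parallel stream uses to carve off work. The class name BatchingDemo and the helper consumedByFirstSplit are made up for this illustration, and the exact batch size is an implementation detail of the JDK rather than a documented guarantee.

```java
import java.util.Iterator;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo class, not part of the question's code.
class BatchingDemo {

    // Returns how many times next() was called on the source iterator
    // by one trySplit() call on the wrapping spliterator.
    static int consumedByFirstSplit(int elementCount) {
        AtomicInteger nextCalls = new AtomicInteger();
        Iterator<Integer> source = new Iterator<>() {
            private int produced = 0;

            @Override
            public boolean hasNext() {
                return produced < elementCount;
            }

            @Override
            public Integer next() {
                nextCalls.incrementAndGet();
                return produced++;
            }
        };

        // 0 means "no special characteristics".
        Spliterator<Integer> spliterator = Spliterators.spliteratorUnknownSize(source, 0);

        // A parallel stream calls trySplit() before any element reaches the consumer.
        spliterator.trySplit();
        return nextCalls.get();
    }

    public static void main(String[] args) {
        System.out.printf("next() calls during the first trySplit(): %d%n",
            consumedByFirstSplit(100));
    }
}
```

With the OpenJDK implementation linked above, I would expect a source of 100 elements to be drained completely by that first split, which matches the output in the question: every next() call finishes before the first result map is printed.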

To avoid this, you can use a different spliterator, or switch the stream to sequential mode by calling the sequential() method before forEach(...).
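A minimal sketch of the sequential() option, assuming a simple counting iterator in place of the remote fetch. The class name SequentialDemo and the event strings are invented for this illustration; the event list records the order in which elements are produced and consumed.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

// Hypothetical demo class, not part of the question's code.
class SequentialDemo {

    // Produces `batches` elements and records when each one is produced and consumed.
    static List<String> run(int batches) {
        List<String> events = new ArrayList<>();
        Iterator<Integer> source = new Iterator<>() {
            private int produced = 0;

            @Override
            public boolean hasNext() {
                return produced < batches;
            }

            @Override
            public Integer next() {
                produced++;
                events.add("produced " + produced);
                return produced;
            }
        };

        // The stream is created as parallel (true), as in the question,
        // then switched to sequential mode before the terminal operation.
        StreamSupport.stream(Spliterators.spliteratorUnknownSize(source, 0), true)
            .sequential()
            .forEach(batch -> events.add("consumed " + batch));
        return events;
    }

    public static void main(String[] args) {
        run(3).forEach(System.out::println);
    }
}
```

In sequential mode each element is handed to the consumer before the next one is requested, so the events alternate between produced and consumed. Passing false as the second argument of StreamSupport.stream should have the same effect as calling sequential().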

One way to keep using a parallel stream is to move the batching in front of the stream and use the Arrays.stream method:

static Stream<Map<String, String>> get(Collection<String> ids) {
    // Round up so that an exact multiple of 10 does not leave a trailing empty chunk.
    int chunkCount = (ids.size() + 9) / 10;
    @SuppressWarnings("unchecked")
    List<String>[] chunks = new List[chunkCount];
    Arrays.setAll(chunks, index -> new ArrayList<>());
    int i = 0;
    for (String id : ids) {
        chunks[i++ / 10].add(id);
    }
    // Chunks may complete on different threads, so the progress counter is atomic.
    AtomicInteger done = new AtomicInteger(0);
    int initialCount = ids.size();
    return Arrays.stream(chunks).map(c -> {
        Map<String, String> result = Map.of();
        try {
            result = fetch(c).get(5, TimeUnit.SECONDS);
            System.out.printf("%s of %s ids done%n", done.addAndGet(c.size()), initialCount);
        } catch (Exception e) {
            System.err.printf("Request thread pool was interrupted: %s%n", e.getMessage());
        }
        return result;
    }).parallel();
}
Answered 2022-02-04T12:22:20.203