This question is NOT about how to use a ThreadLocal. My question is about the side effect of the ForkJoinPool continuation of ForkJoinTask.compute() which breaks the ThreadLocal contract.

In a ForkJoinTask.compute(), I pull an arbitrary static ThreadLocal.

The value is some arbitrary stateful object, but its state does not need to outlive the compute() call. In other words, I prepare the thread-local object/state, use it, then dispose of it.

In principle you would put that state in the ForkJoinTask, but just assume this thread-local value lives in a 3rd-party lib I cannot change. Hence the static ThreadLocal, as it is a resource that all task instances share.

I anticipated, tested and proved that a simple ThreadLocal gets initialized only once per thread, of course. This means that, because a worker thread continues with other tasks beneath the ForkJoinTask.join() call, my compute() method can get called again on the same thread before the previous invocation has even exited. This exposes the state of the object in use by the earlier compute() call, many stack frames higher.

How do you solve that undesirable exposure issue?

The only way I currently see is to ensure a new thread for every compute() call, but that defeats the F/J pool continuation and could dangerously explode the thread count.

Isn't there something that could be done in the JRE core to back up the thread-locals that changed since the first ForkJoinTask and revert the entire thread-local map, so that every task.compute() behaves as if it were the first to run on the thread?

Thanks.

package jdk8tests;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

public class TestForkJoin3 {

    static AtomicInteger nextId = new AtomicInteger();
    static long T0 = System.currentTimeMillis();
    static int NTHREADS = 5;
    static final ThreadLocal<StringBuilder> myTL = ThreadLocal.withInitial( () -> new StringBuilder());

    static void log(Object msg) {
        // Elapsed seconds since start; the double is autoboxed (the Double constructor is deprecated).
        System.out.format("%09.3f %-10s %s%n", 0.001*(System.currentTimeMillis()-T0), Thread.currentThread().getName(), " : "+msg);
    }

    public static void main(String[] args) throws Exception {
        ForkJoinPool p = new ForkJoinPool(
                NTHREADS,
                pool -> {
                    int id = nextId.incrementAndGet(); //count new threads
                    log("new FJ thread "+ id);
                    ForkJoinWorkerThread t = new ForkJoinWorkerThread(pool) {/**/};
                    t.setName("My FJThread "+id);
                    return t;
                },
                Thread.getDefaultUncaughtExceptionHandler(),
                false
        );

        LowercasingTask t = new LowercasingTask("ROOT", 3);
        p.invoke(t);

        int nt = nextId.get();
        log("number of threads was "+nt);
        if(nt > NTHREADS)
            log(">>>>>>> more threads than prescribed <<<<<<<<");
    }


    //=====================

    static class LowercasingTask extends RecursiveTask<String> {
        String name;
        int level;
        public LowercasingTask(String name, int level) {
            this.name = name;
            this.level = level;
        }

        @Override
        protected String compute() {
            StringBuilder sbtl = myTL.get();
            String initialValue = sbtl.toString();
            if(!initialValue.equals(""))
                log("!!!!!! BROKEN ON START!!!!!!! value = "+ initialValue);

            sbtl.append(":START");

            if(level>0) {
                log(name+": compute level "+level);
                try {Thread.sleep(10);} catch (InterruptedException e) {e.printStackTrace();}

                List<LowercasingTask> tasks = new ArrayList<>();
                for(int i=1; i<=9; i++) {
                    LowercasingTask lt = new LowercasingTask(name+"."+i, level-1);
                    tasks.add(lt);
                    lt.fork();
                }

                for(int i=0; i<tasks.size(); i++) { //this can lead to compensation threads due to l1.join() method running lifo task lN
                //for(int i=tasks.size()-1; i>=0; i--) { //this usually has the lN.join() method running task lN, without compensation threads.
                    tasks.get(i).join();
                }

                log(name+": returning from joins");

            }

            sbtl.append(":END");

            String val = sbtl.toString();
            if(!val.equals(":START:END"))
                log("!!!!!! BROKEN AT END !!!!!!! value = "+val);

            sbtl.setLength(0);
            return "done";
        }
    }

}

1 Answer


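I don't believe so. Not in general, and especially not for ForkJoinTask, where tasks are expected to be pure functions on isolated objects.

Sometimes it is possible to reorder the task so that it forks and joins at the beginning, before doing its own work. That way each subtask initializes and disposes of the thread-local before returning. If that is not possible, maybe you can treat the thread-local as a stack and push, clear, and restore the value around each join.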
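
The push/clear/restore idea could look roughly like the sketch below. It is only an illustration under stated assumptions: `ScopedThreadLocalSketch`, `isolatedJoin`, `Task`, `leaks` and `run` are hypothetical names, `myTL` stands in for the third-party thread-local from the question, and it assumes the library tolerates `remove()` followed by a fresh initial value. Every join in compute() has to go through the helper, otherwise a nested compute() on the same thread can still see the outer state.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.atomic.AtomicInteger;

class ScopedThreadLocalSketch {

    // Stand-in for the third-party static ThreadLocal from the question.
    static final ThreadLocal<StringBuilder> myTL = ThreadLocal.withInitial(StringBuilder::new);

    // Counts compute() calls that observed state belonging to another call.
    static final AtomicInteger leaks = new AtomicInteger();

    // Hypothetical helper: join with the caller's thread-local value parked out of sight.
    static <T> T isolatedJoin(RecursiveTask<T> task) {
        StringBuilder saved = myTL.get(); // push: remember this frame's state
        myTL.remove();                    // clear: nested compute() calls start from a fresh value
        try {
            return task.join();           // may run other tasks on this same thread
        } finally {
            myTL.set(saved);              // restore: put this frame's state back
        }
    }

    static class Task extends RecursiveTask<String> {
        final int level;

        Task(int level) {
            this.level = level;
        }

        @Override
        protected String compute() {
            StringBuilder sb = myTL.get();
            if (sb.length() != 0) leaks.incrementAndGet(); // dirty on entry
            sb.append(":START");

            if (level > 0) {
                List<Task> children = new ArrayList<>();
                for (int i = 0; i < 9; i++) {
                    Task child = new Task(level - 1);
                    children.add(child);
                    child.fork();
                }
                for (Task child : children) {
                    isolatedJoin(child); // never call child.join() directly
                }
            }

            sb.append(":END");
            if (!sb.toString().equals(":START:END")) leaks.incrementAndGet(); // dirty on exit
            sb.setLength(0); // dispose, as in the question
            return "done";
        }
    }

    // Runs a three-level task tree on a small pool and reports how many leaks were seen.
    static int run() {
        leaks.set(0);
        ForkJoinPool pool = new ForkJoinPool(5);
        try {
            pool.invoke(new Task(3));
        } finally {
            pool.shutdown();
        }
        return leaks.get();
    }

    public static void main(String[] args) {
        System.out.println("leaked states observed: " + run());
    }
}
```

The cost is one extra thread-local value allocated per join, and the discipline only holds for code you control: any other call inside compute() that lets the worker run queued tasks would need the same wrapping.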

answered 2016-01-31T23:18:42.163