您的当前位置:首页正文

ForkJoin框架详解 一张图搞明白工作窃取(work-stealing)机制

2024-11-29 来源:个人技术集锦

1 ForkJoin框架 

1.1 ForkJoin框架

ForkJoinPool一种ExecutorService的实现,运行ForkJoinTask任务。ForkJoinPool区别于其它ExecutorService,主要是因为它采用了一种工作窃取(work-stealing)的机制。所有被ForkJoinPool管理的线程尝试窃取提交到池子里的任务来执行,执行中又可产生子任务提交到池子中。

    ForkJoinPool维护了一个WorkQueue的数组(数组长度是2的整数次方,自动增长)。每个workQueue都有任务队列(ForkJoinTask的数组),并且用base、top指向任务队列队尾和队头。work-stealing机制就是工作线程挨个扫描任务队列,如果队列不为空则取队尾的任务并执行。示意图如下:

1.2 demo小程序

    创建一个包含2个worker线程的pool,main线程提交2个任务(task-1、task-2),触发worker线程工作,task-1任务fork出4个子任务。main线程负责同步两个worker线程的工作进度。demo小程序演示了worker-2窃取worker-1的子任务:

package com.focuse.jdkdemo.concurrent;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

/**
 * @date :Created in 2020/2/16 上午11:10
 * @description:
 * @modified By:
 */
public class ForkJoinDemo {
    private static Thread worker1 = null;
    private static Thread worker2 = null;
    private static AtomicBoolean  worker1Park = new AtomicBoolean();
    private static AtomicBoolean  worker2Park = new AtomicBoolean();
    static {
        worker1Park.set(false);
        worker2Park.set(false);
    }

    public abstract static class ForkJoinTaskDemo extends ForkJoinTask {
        @Override
        public Object getRawResult() {
            return null;
        }

        @Override
        protected void setRawResult(Object value) {

        }
    }

    public static void setParkFlag() {
        if (Thread.currentThread().getName().equals("worker-1")) {
            //自旋设置, 保证与main线程同步
            while (!worker1Park.compareAndSet(false, true)) {

            }
        } else if (Thread.currentThread().getName().equals("worker-2")) {
            //自旋设置, 保证与main线程同步
            while (!worker2Park.compareAndSet(false, true)) {

            }
        }
    }

    private static Runnable task1 = new Runnable() {
        @Override
        public void run() {
            worker1 = Thread.currentThread();
            worker1.setName("worker-1");
            System.out.println(Thread.currentThread().getName() + " execute task1");
            //暂停并且设置暂停flag以便main结束自旋
            setParkFlag();
            LockSupport.park();
            ForkJoinTask task11 = new ForkJoinTaskDemo() {
                @Override
                protected boolean exec() {
                    System.out.println(Thread.currentThread().getName() + " execute task1-1");
                    //暂停并且设置暂停flag以便main结束自旋
                    setParkFlag();
                    LockSupport.park();
                    return true;
                }
            };
            task11.fork();
            ForkJoinTask task12 = new ForkJoinTaskDemo() {
                @Override
                protected boolean exec() {
                    System.out.println(Thread.currentThread().getName() + " execute task1-2");
                    //暂停并且设置暂停flag以便main结束自旋
                    setParkFlag();
                    LockSupport.park();
                    return true;
                }
            };
            task12.fork();
            ForkJoinTask task13 = new ForkJoinTaskDemo() {
                @Override
                protected boolean exec() {
                    System.out.println(Thread.currentThread().getName() + " execute task1-3");
                    //暂停并且设置暂停flag以便main结束自旋
                    setParkFlag();
                    LockSupport.park();
                    return true;
                }
            };
            task13.fork();
            ForkJoinTask task14 = new ForkJoinTaskDemo() {
                @Override
                protected boolean exec() {
                    System.out.println(Thread.currentThread().getName() + " execute task1-4");
                    //暂停并且设置暂停flag以便main结束自旋
                    setParkFlag();
                    LockSupport.park();
                    return true;
                }
            };
            task14.fork();
            //提交4个子任务后暂停
            setParkFlag();
            LockSupport.park();
        }
    };

    private static Runnable task2 = new Runnable() {
        @Override
        public void run() {
            worker2 = Thread.currentThread();
            worker2.setName("worker-2");
            System.out.println(Thread.currentThread().getName() + " execute task2");
            //暂停并且设置暂停flag以便main结束自旋
            setParkFlag();
            LockSupport.park();
        }
    };

    public static void main(String[] args) {
        //只用两个线程,方便测试
        ForkJoinPool pool = new ForkJoinPool(2);
        //step-1、step-2 提交2各任务,期望worker-1 worker-2各执行一个任务
        System.out.println("step-1 step-2: 提交2个任务,触发2个worker线程");
        pool.submit(task1);
        pool.submit(task2);

        while (!worker1Park.get()) {
        }
        //step-3  唤醒worker-1 产生4个任务 然后worker-1继续暂停
        System.out.println("\n*******************************************");
        System.out.println("step-3 唤醒worker-1 产生4个子任务");
        LockSupport.unpark(worker1);
        //自旋设置, 保证与worker-1线程同步
        while (!worker1Park.compareAndSet(true, false)) {
        }


        while (!worker1Park.get() || !worker2Park.get()) {
        }
        //step-4 唤醒worker-1 worker-2
        System.out.println("\n*******************************************");
        System.out.println("step-4 唤醒worker-1、worker-2:worker-1弹出task1-4执行,worker-2窃取task1-1执行");
        LockSupport.unpark(worker1);
        LockSupport.unpark(worker2);
        //自旋设置, 保证与worker-1线程同步
        while (!worker1Park.compareAndSet(true, false)) {
        }
        //自旋设置, 保证与worker-2线程同步
        while (!worker2Park.compareAndSet(true, false)) {
        }

        while (!worker1Park.get() || !worker2Park.get()) {
        }
        //step-5 唤醒worker-1 worker-2
        System.out.println("\n*******************************************");
        System.out.println("step-5 唤醒worker-1、worker-2:worker-1弹出task1-3执行,worker-2窃取task1-2执行");
        LockSupport.unpark(worker1);
        LockSupport.unpark(worker2);
        //自旋设置, 保证与worker-1线程同步
        while (!worker1Park.compareAndSet(true, false)) {
        }
        //自旋设置, 保证与worker-2线程同步
        while (!worker2Park.compareAndSet(true, false)) {
        }


        while (!worker1Park.get() || !worker2Park.get()) {
        }
        //唤醒worker-1 worker-2 结束
        LockSupport.unpark(worker1);
        LockSupport.unpark(worker2);
        //自旋设置, 保证与worker-1线程同步
        while (!worker1Park.compareAndSet(true, false)) {
        }
        //自旋设置, 保证与worker-2线程同步
        while (!worker2Park.compareAndSet(true, false)) {
        }

    }
}

运行结果如下:

2 源码解读

    这一part围绕线程池(ForkJoinPool)结构已经线程池的关键动作讲一下过程,线程池的关键动作无非是提交任务、运行任务、获取任务结果。不过,对于ForkJoinPool有还有任务fork动作。

    ForkJoinPool里面大量用到bit运算。做一下简短说明:计算机的运算以补码计算。补码怎么来?正数的补码就是原码,最高位是符号位为0;负数的补码是其正数的补码按位取反(包括符号位),最后加1。

2.1 ForkJoinPool 和 WorkQueue

    ForkJoinPool几个主要的成员变量说明如下:

  • config 是创建ForkJoinPool的配置,int类型32bits,高16位表示pool的mode(FIFO或LIFO),低16位表示parallelism(并行度,默认大小可用处理器数java.lang.Runtime#availableProcessors);
  • ctl 是ForkJoinPool的主要控制字段,long类型64bits,ctl不同的bit位表示不同的含义;
    • 高16位(63~48)表示活跃的线程,值为活跃线程数减去parallelism(补码表示),初始值是0-parallelism,工作线程激活则加1,去激活则减1。当累积加了parallelism时第63bit位翻转为0,则不允许再激活工作线程。
    • 第47~32位表示当前所有工作线程(包括未激活的),值为所有工作线程数-parallelism(补码表示),创建线程则加1,终止线程则减1。当累积加了parallelism时第47位翻转位0,则不允许再创建线程;
    • 第31~16位表示非激活线程链中top线程的版本计数和线程状态,与第15~0位合起来看;
    • 第15~0位表示非激活线程链中top线程的本地WorkQueue在ForkJoinPool.workQueues数组中下标索引,第31~0位合起来的值实际是非激活线程链中top线程的本地WorkQueue.scanState
  • workQueues 是ForkJoinPool维护一个WorkQueue数组,奇数下标的WorkQueue关联一个worker线程,偶数下标的WorkQueue用来接收外部提交的任务(非worker线程提交的任务);
  • factory 创建worker线程的工厂;

源码如下(笔者添加了注释):

public class ForkJoinPool extends AbstractExecutorService {
    ......
    //笔者注:runState的bit位:SHUTDOWN是负数(int的最高bit位为1),其余的是2的整数次方
    private static final int  RSLOCK     = 1;
    private static final int  RSIGNAL    = 1 << 1;
    private static final int  STARTED    = 1 << 2;
    private static final int  STOP       = 1 << 29;
    private static final int  TERMINATED = 1 << 30;
    private static final int  SHUTDOWN   = 1 << 31;

    // Instance fields
    //笔者注:ctl是ForkJoinPool的控制字段。long是64位bit,
    //最高16位(63~48)表示活跃的线程,值为活跃的线程-parallelism(补码表示),
    //第47~32位表示当前所有工作线程(包括未激活的),值为所有工作线程数-parallelism(补码表示),
    //第31~16位表示waiters线程链中top线程的版本计数和线程状态
    //第15~0位表示waiters线程链中top线程的本地WorkQueue在ForkJoinPool.workQueues的下标索引
    volatile long ctl;                   // main pool control
    volatile int runState;               // lockable status
    //笔者注:config高16位是mode(FIFO或者LIFO),低16位是parallelism
    final int config;                    // parallelism, mode
    //笔者注:防止index冲突的随机数种子
    int indexSeed;                       // to generate worker index
    //笔者注:WorkQueue数组,奇数WorkQueue有worker线程,偶数WorkQueue接收外部提交任务
    volatile WorkQueue[] workQueues;     // main registry
    //工作线程的创建工厂
    final ForkJoinWorkerThreadFactory factory;
    final UncaughtExceptionHandler ueh;  // per-worker UEH
    final String workerNamePrefix;       // to create worker name string
    //工作窃取的计数
    volatile AtomicLong stealCounter;    // also used as sync monitor
    ......
}

WorkQueue的几个重要成员变量说明如下:

  • scanState int类型32bits,各bit位含义如下:
    • 第31位表示线程状态(1非激活),第30~16位表示版本计数;
    • 第0位表示worker线程是否在运行任务(1-scanning,0-busy),这里有个小技巧在创建worker线程的WorkQueue时scanState的第15~0位初始化为ForkJoinPool.workQueues的下标(worker线程的WorkQueue的下标是奇数),当worker线程运行任务时第0位设置0(busy),任务运行结束第0位又设置1(恢复为奇数),所以scanState的第15~0又可以表示在ForkJoinPool.workQueues数组的下标索引
  • stackPred  当worker线程从激活变为非激活时设置值,且值为ForkJoinPool的ctl的低32位(实际是前一个非激活线程),这样就形成了一个非激活线程链;
  • config 高16位是mode(FIFO或者LIFO),低16位是ForkJoinPool.workQueues的下标
  • base 任务队列的队尾,工作窃取就是窃取base指向的任务;
  • top 任务队列的队头(指向空),下一个push的位置;
  • array 任务队列;
  • pool 所属的ForkJoinPool实例;
  • owner 所属的worker线程,如果在ForkJoinPool.workQueues数组中下标是奇数,则不为空。

源码如下(笔者添加了注释):

static final class WorkQueue {
    ......
    // Instance fields
    //笔者注:最高位为1表示非激活,第30~16位版本计数,第0表示是否在运行任务(1-scanning,0-busy)
    //
    volatile int scanState;    // versioned, <0: inactive; odd:scanning
    //笔者注:实际是前一个非激活线程,这样就形成了一个waiters线程链
    int stackPred;             // pool stack (ctl) predecessor
    int nsteals;               // number of steals
    int hint;                  // randomization and stealer index hint
    //高16位是mode(FIFO或LIFO),低16位是ForkJoinPool.workQueues的下标
    int config;                // pool index and mode
    volatile int qlock;        // 1: locked, < 0: terminate; else 0
    //笔者注:任务队列的队尾,工作窃取就是窃取base指向的任务
    volatile int base;         // index of next slot for poll
    //笔者注:任务队列的队头(指向空),下一个push的位置
    int top;                   // index of next slot for push
    //笔者注:任务队列
    ForkJoinTask<?>[] array;   // the elements (initially unallocated)
    //笔者注:所属的ForkJoinPool实例
    final ForkJoinPool pool;   // the containing pool (may be null)
    //笔者注:所属的worker线程,如果在ForkJoinPool.workQueues数组中的下标是奇数则不为空
    final ForkJoinWorkerThread owner; // owning thread or null if shared
    volatile Thread parker;    // == owner during call to park; else null
    volatile ForkJoinTask<?> currentJoin;  // task being joined in awaitJoin
    volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer
    ......
}

2.2 提交任务

    ForkJoinPool作为ExecutorService的一个实现类,有submit方法提交任务,直接贴源码出来如下:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    //提交任务
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }
    ......
}

   

submit方法判断了一下任务是否为null,然后直接调用externalPush,源码如下:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    /**
    **/
    final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        int rs = runState;
        //笔者注:(1)计算一个偶数下标,如果该下标下WorkQueue不为空则尝试添加到该WorkQueue
        if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
            (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {// 笔者注:(2)加锁锁住改WorkQueue
            ForkJoinTask<?>[] a; int am, n, s;
            //笔者注:WorkQueue的任务队列不为null且未满 a.length - 1 > q.top - q.base
            if ((a = q.array) != null &&
                (am = a.length - 1) > (n = (s = q.top) - q.base)) {
                int j = ((am & s) << ASHIFT) + ABASE;
                U.putOrderedObject(a, j, task);
                U.putOrderedInt(q, QTOP, s + 1);
                //笔者注:释放锁
                U.putIntVolatile(q, QLOCK, 0);
                //笔者注:唤醒工作线程
                if (n <= 1)
                    signalWork(ws, q);
                return;
            }
            //笔者注:释放锁
            U.compareAndSwapInt(q, QLOCK, 1, 0);
        }
        externalSubmit(task);
    }
    ......
}

submit的任务是提交到偶数下标的workQueue中。函数externalPush先计算一个下标位置"m & r & SQMASK",这个位置是偶数下标,如何保证是偶数?"m & r & SQMASK"又是什么含义?m是pool.workQueues的数组size - 1(即2的整数次方-1),跟m做&运算能保证下标不超过workQueues.size;  r是当前线程的局部变量,取这个值尽可能避免冲突;SQMASK是一个常量0x007e,第0bit位是0,这就保证了下标计算出来一定是偶数。

  • (1) 判断对应下标的workQueues的元素不为空且pool的状态正常;
  • (2)加锁 U.compareAndSwapInt(q, QLOCK, 0, 1)) ;
  • (3)如果对应WorkQueue中任务队列已初始化(不等于null)且未满,则加入该队列;
  • (4)释放锁 U.compareAndSwapInt(q, QLOCK, 1, 0);
  • (5) 如果添加成功,唤醒工作线程 sinalWork;
  • (6)如果没有成功添加,调用externalSubmit方法;

我们再看一下externalSubmit函数源码如下:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    private void externalSubmit(ForkJoinTask<?> task) {
        int r;                                    // initialize caller's probe
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        //笔者注:循环操作,每次循环的动作都非常小,符合条件就跳出来
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            //笔者注: runState为负数,则表示shutdown,终止线程池并跳出循环
            if ((rs = runState) < 0) {
                tryTerminate(false, false);     // help terminate
                throw new RejectedExecutionException();
            }
            //笔者注:pool未初始化,则初始化然后继续循环
            else if ((rs & STARTED) == 0 ||     // initialize
                     ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
                int ns = 0;
                rs = lockRunState();
                try {
                    if ((rs & STARTED) == 0) {
                        U.compareAndSwapObject(this, STEALCOUNTER, null,
                                               new AtomicLong());
                      //笔者注:这里保证workQueues是2的整数次方,方法巧妙,n经过一系列位运算
                       //变成了一个从第0位开始出现连续1的直到出现一个bit位是0,然后就一直是0,
                       //这样最后再加1就变成了2的整数次方。
                        // create workQueues array with size a power of two
                        int p = config & SMASK; // ensure at least 2 slots
                        int n = (p > 1) ? p - 1 : 1;
                        n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                        n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                        workQueues = new WorkQueue[n];
                        ns = STARTED;
                    }
                } finally {
                    unlockRunState(rs, (rs & ~RSLOCK) | ns);
                }
            }
            //笔者注:如果workQueue不为空,则加锁尝试添加。添加成功唤醒工作线程并跳出,否则继续循环
            else if ((q = ws[k = r & m & SQMASK]) != null) {
                if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                    ForkJoinTask<?>[] a = q.array;
                    int s = q.top;
                    boolean submitted = false; // initial submission or resizing
                    try {                      // locked version of push
                        if ((a != null && a.length > s + 1 - q.base) ||
                            (a = q.growArray()) != null) {
                            int j = (((a.length - 1) & s) << ASHIFT) + ABASE;
                            U.putOrderedObject(a, j, task);
                            U.putOrderedInt(q, QTOP, s + 1);
                            submitted = true;
                        }
                    } finally {
                        U.compareAndSwapInt(q, QLOCK, 1, 0);
                    }
                    if (submitted) {
                        signalWork(ws, q);
                        return;
                    }
                }
                move = true;                   // move on failure
            }
            // 笔者注:如果pool未被锁,则创建一个WorkQueue,注意k在前面一个elseif赋值了一个偶数
            else if (((rs = runState) & RSLOCK) == 0) { // create new queue
                q = new WorkQueue(this, null);
                q.hint = r;
                q.config = k | SHARED_QUEUE;
                q.scanState = INACTIVE;
                rs = lockRunState();           // publish index
                if (rs > 0 &&  (ws = workQueues) != null &&
                    k < ws.length && ws[k] == null)
                    ws[k] = q;                 // else terminated
                unlockRunState(rs, rs & ~RSLOCK);
            }
            else
                move = true;                   // move if busy
            //笔者注:重新获取下标
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }
    ......
}

这里是一个循环,每次循环依次判断条件,符合条件就跳出来:

  • (1)判断runState是否为负数,如果是负数,则表示shutdown,终止线程池并返回;
  • (2)否则继续判断,如果pool未初始化,则初始化之后然后开始另一次循环从(1)开始判断;
  • (3)否则继续判断,计算一个偶数下标k,如果workQueues[k]不为null,尝试添加任务进去,成功则唤醒工作线程并返回;
  • (4)否则继续判断,如果pool未被其它线程锁住,则创建一个WorkQueue赋值给workQueues[k],注意k在(3)中已经赋予了一个偶数,然后开始另一次循环从(1)开始判断;
  •  (5)否则,开始另一次循环从(1)开始判断;

submit任务添加完成,又1个疑问没揭开:为什么判断任务队列(workQueue.array)未满用a.length - 1 > q.top - q.base,感觉不需要-1? 不过这不影响阅读,也不影响整体对ForkJoin的理解。先记录下来,未来可能整明白。

    此外,ForkJoin框架的还可以通过ForkJoinTask.fork来添加任务,源码如下:

public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
    ......
    //如果是work线程调用fork则添加到work线程的本地队列里面,否则添加到commonPool池子里面    
    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }
    ......
}

2.3 运行任务

     这一节主要说明工作线程是怎么启动的?主要是添加任务时调用的signalWork:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        //笔者注:ctl的高16位是激活线程数-parallelism的反码,最高位为1表示激活线程还可以增加
        while ((c = ctl) < 0L) {                       // too few active
            //笔者注:ctl的第32位是非激活线程链的top线程,为0表示没有非激活线程
            if ((sp = (int)c) == 0) {                  // no idle workers
              //笔者注:ctl的47~32位是线程总数-parallelism的反码,最高为1表示总的线程还可以加
                if ((c & ADD_WORKER) != 0L)            // too few workers
                    tryAddWorker(c);
                break;
            }
            if (ws == null)                            // unstarted/terminated
                break;
            if (ws.length <= (i = sp & SMASK))         // terminated
                break;
            if ((v = ws[i]) == null)                   // terminating
                break;
            //笔者注:以下逻辑是获取非激活线程链的top线程然后激活,再激活之前要将下一个非激活线程(stackPred表示)放到ctl里面
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int d = sp - v.scanState;                  // screen CAS
            long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
            if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;                      // activate v
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
            if (q != null && q.base == q.top)          // no more work
                break;
        }
    }
    ......
}
  • 先根据pool的控制字段ctl判断是否需要添加新的worker线程,如果需要添加并返回
  • 否则,找到非激活线程链的top线程并激活,主要在激活之前,先要将下一个非激活线程(stackPred指向)设置到ctl,使之成为top线程

创建worker则是调用pool中的factory(ForkJoinWorkerThreadFactory)创建一个工作线程,并将其注册到pool中,注册过程如下:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
        UncaughtExceptionHandler handler;
        wt.setDaemon(true);                           // configure thread
        if ((handler = ueh) != null)
            wt.setUncaughtExceptionHandler(handler);
        WorkQueue w = new WorkQueue(this, wt);
        int i = 0;                                    // assign a pool index
        int mode = config & MODE_MASK;
        int rs = lockRunState();
        try {
            WorkQueue[] ws; int n;                    // skip if no array
            if ((ws = workQueues) != null && (n = ws.length) > 0) {
                int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
                int m = n - 1;
                //保证i是奇数 | 1
                i = ((s << 1) | 1) & m;               // odd-numbered indices
                if (ws[i] != null) {                  // collision
                    int probes = 0;                   // step by approx half n
                    int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                    while (ws[i = (i + step) & m] != null) {
                        if (++probes >= n) {
                            workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                            m = n - 1;
                            probes = 0;
                        }
                    }
                }
                w.hint = s;                           // use as random seed
                w.config = i | mode;
                //笔者注:下标赋值scanState,因为worker线程最大值0x7fff+1,workQueues的size是其2倍,所以下标最大值0xffff(16bits) 
                w.scanState = i;                      // publication fence
                ws[i] = w;
            }
        } finally {
            unlockRunState(rs, rs & ~RSLOCK);
        }
        wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
        return w;
    }
    ......
}

注册实际上就是创建一个WorkQueue对象维护起来,这里不多介绍,就注意一点worker线程本地队列一定是ForkJoinPool.workQueues的奇数下标的元素。“i = ((s << 1) | 1) & m” 中位运算"| 1"就保证i是奇数,后面递增的step是2的整数次方,所以i = i + step还是奇数。同时,下标索引赋值给WorkQueue的scanState变量,因为worker线程最大值0x7fff+1,workQueues的size是其2倍,所以下标最大值0xffff(16bits)。

    work线程创建完成直接start:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            if (fac != null && (wt = fac.newThread(this)) != null) {
                //笔者注:启动工作线程
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }
    ......
}

线程启动执行ForkJoinWorkerThread.run,然后又会调用ForkJoinPool.runWorker

public class ForkJoinWorkerThread extends Thread {
    ......
    public void run() {
        if (workQueue.array == null) { // only run once
            Throwable exception = null;
            try {
                onStart();
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                ......
            }
        }
    }
    ......
}


public class ForkJoinPool extends AbstractExecutorService {
    ......
    final void runWorker(WorkQueue w) {
        //笔者注:扩容w的任务队列,此处第一次调用相当于初始化
        w.growArray();                   // allocate queue
        int seed = w.hint;               // initially holds randomization hint
        int r = (seed == 0) ? 1 : seed;  // avoid 0 for xorShift
        for (ForkJoinTask<?> t;;) {
            //笔者注:扫描窃取任务,如果窃取到则执行任务
            if ((t = scan(w, r)) != null)
                w.runTask(t);
            else if (!awaitWork(w, r))
                break;
            r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
        }
    }
    ......
}

runWorker是最顶层的循环,线程一直循环直到终止。runWorker调用scan方法,scan就是工作窃取的核心实现,源码如下:

public class ForkJoinPool extends AbstractExecutorService {
    ......
    private ForkJoinTask<?> scan(WorkQueue w, int r) {
        WorkQueue[] ws; int m;
        if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
            int ss = w.scanState;                     // initially non-negative
            for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
                WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
                int b, n; long c;
                if ((q = ws[k]) != null) {
                    if ((n = (b = q.base) - q.top) < 0 &&
                        (a = q.array) != null) {      // non-empty
                        long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                        if ((t = ((ForkJoinTask<?>)
                                  U.getObjectVolatile(a, i))) != null &&
                            q.base == b) {
                            if (ss >= 0) {
                   //笔者注:任务存在,竞争窃取任务,竞争成功返回窃取成果,如果队列中任务还有剩,唤醒非激活线程
                                if (U.compareAndSwapObject(a, i, t, null)) {
                                    q.base = b + 1;
                                    if (n < -1)       // signal others
                                        signalWork(ws, q);
                                    return t;
                                }
                            }
                  //笔者注:如果当前线程 未激活,尝试激活非激活链的top线程(如果oldSum不为0则再扫描一次并且会将oldSum设置为0)
                            else if (oldSum == 0 &&   // try to activate
                                     w.scanState < 0)
                                tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                        }
                        if (ss < 0)                   // refresh
                            ss = w.scanState;
                        r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                        origin = k = r & m;           // move and rescan
                        oldSum = checkSum = 0;
                        continue;
                    }
                    checkSum += b;
                }
                //笔者注:又扫描了一圈
                if ((k = (k + 1) & m) == origin) {    // continue until stable
                //笔者注:如果checSum==oldSum 说明连续2圈扫描过程中各队列的base都没变(即没有任务),
                //设置线程未激活。然后继续扫描,如果第3圈扫描到任务,则尝试激活线程,如果第3圈还            
                //没扫到且checkSum又没变,那就返回null,尝试让线程park。
                    if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                        oldSum == (oldSum = checkSum)) { 
                        if (ss < 0 || w.qlock < 0)    // already inactive
                            break;
                        int ns = ss | INACTIVE;       // try to inactivate
                        long nc = ((SP_MASK & ns) |
                                   (UC_MASK & ((c = ctl) - AC_UNIT)));
                        w.stackPred = (int)c;         // hold prev stack top
                        U.putInt(w, QSCANSTATE, ns);
                        if (U.compareAndSwapLong(this, CTL, c, nc))
                            ss = ns;
                        else
                            w.scanState = ss;         // back out
                    }
                    checkSum = 0;
                }
            }
        }
        return null;
    }
    ......
}
  • (1)扫描到任务且线程是激活的,并且竞争窃取任务(其它线程也可能扫描到)。竞争成功返回窃取成果,竞争失败继续扫描;
  • (2)扫描到任务但线程是未激活的(可能此前连续2圈都没扫描到任务),则尝试release非激活线程链的top线程(如果oldSum==0,继续扫描并将oldSum、checkSum设置为0以便下次能够尝试release线程);
  • (3)如果连续扫描pool.workQueues 两圈都没有任务,oldSum==checkSum 说明连续两圈各队列的base值相同(没有被窃取过)、且队列中当前没有任务,那么设置线程未激活。继续扫描第3圈:
    • 第3圈扫描到任务,就会进入(2) 因为已经线程已经设置为未激活了
    • 第3圈未扫描到任务,如果checkSum还是未变,跳出循环返回null,让线程park;如果checkSum变了,继续第4圈重复第3圈的过程(此时oldSum已经是第3圈的checkSum)。
    • 如果返回null,runWorker会尝试park当前线程。

2.4 获取结果

    ForkJoinTask是Future的实现类,所以get方法可以获取结果。不过,作为ForkJoin框架的任务,有自己独有的方法-join()

public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
    ......
    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

    
    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

private int externalAwaitDone() {
        int s = ((this instanceof CountedCompleter) ? // try helping
                 ForkJoinPool.common.externalHelpComplete(
                     (CountedCompleter<?>)this, 0) :
                 ForkJoinPool.common.tryExternalUnpush(this) ? doExec() : 0);
        if (s >= 0 && (s = status) >= 0) {
            boolean interrupted = false;
            do {
                //笔者注:设置当前任务的信号bit位,当任务结束时就会调用notifyAll
                if (U.compareAndSwapInt(this, STATUS, s, s | SIGNAL)) {
                    synchronized (this) {
                        if (status >= 0) {
                            try {
                                //等待该任务对象notify
                                wait(0L);
                            } catch (InterruptedException ie) {
                                interrupted = true;
                            }
                        }
                        else
                            notifyAll();
                    }
                }
            } while ((s = status) >= 0);
            if (interrupted)
                Thread.currentThread().interrupt();
        }
        return s;
    }


    private int setCompletion(int completion) {
        for (int s;;) {
            if ((s = status) < 0)
                return s;
            if (U.compareAndSwapInt(this, STATUS, s, s | completion)) {
                //笔者注:如果状态的信号位(第17位)是1,则notifyAll唤醒所有等待结果的线程
                if ((s >>> 16) != 0)
                    synchronized (this) { notifyAll(); }
                return completion;
            }
        }
    }
    ......
}

    join方法最终调用externalAwaitDone方法,externalAwaitDone设置任务的信号位为1,然后调用wait等待task对象唤醒。当任务完成时调用setCompletion方法。setCompletion判断如果信号位为1(说明有线程等待结果),则唤醒线程(notifyAll)。get方法和join方法差别不大,主要是对异常的处理和线程中断的处理不一样,此处不做赘述。

    好记性不如烂笔头,随时记录当下的理解!

显示全文