写点什么

快速掌握并发编程 ---CountDownLatch 原理和实战

用户头像
田维常
关注
发布于: 2020 年 11 月 02 日

关注Java 后端技术全栈”**


回复“000”获取大量电子书


常见面试题


如何实现让主线程等所有子线程执行完了后,主要线程再继续执行?即如何实现一个线程等其他线程执行完了后再继续执行?


方法一


在前面的文章中我们介绍了 Thread 类的 join 方法:快速掌握并发编程---Thread常用方法,join 的工作原理是,不停检查 thread 是否存活,如果存活则让当前线程永远 wait,直到 thread 线程终止,线程的 notifyAll 就会被调用。


下面我们就使用 join 来实现上面面试题。


import java.util.Random;import java.util.concurrent.CountDownLatch;public class CountDownLatchDemo {    public static void main(String[] args) {        System.out.println("主要线程开始等待其他子线程执行");        try {            test();        } catch (InterruptedException e) {            e.printStackTrace();        }    }    public static void test() throws InterruptedException {       Thread thread1 = new Thread(() -> {            System.out.println(Thread.currentThread().getName() + " 线程开始");            Random random = new Random();            try {                Thread.sleep(random.nextInt(10000) + 1000);            } catch (InterruptedException e) {                e.printStackTrace();            }            System.out.println( Thread.currentThread().getName() + " 线程执行完毕");        },"线程1");       thread1.start();        Thread thread2 = new Thread(() -> {            System.out.println(Thread.currentThread().getName() + " 线程开始");            Random random = new Random();            try {                Thread.sleep(random.nextInt(10000) + 1000);            } catch (InterruptedException e) {                e.printStackTrace();            }            System.out.println(Thread.currentThread().getName() + " 线程执行完毕");        },"线程2");        thread2.start();        Thread thread3 = new Thread(() -> {            System.out.println(Thread.currentThread().getName() + " 线程开始");            Random random = new Random();            try {                Thread.sleep(random.nextInt(10000) + 1000);            } catch (InterruptedException e) {                e.printStackTrace();            }            System.out.println( Thread.currentThread().getName() + " 线程执行完毕");        },"线程3");        thread3.start();        Thread thread4 = new Thread(() -> {            System.out.println(Thread.currentThread().getName() + " 线程开始");            Random random = new Random();            try {                Thread.sleep(random.nextInt(10000) + 1000);            } catch (InterruptedException e) {                e.printStackTrace();            }            System.out.println(Thread.currentThread().getName() + " 线程执行完毕");        },"线程4");        thread4.start();        //启动了四个线程,然后让四个线程一直检测自己是否已经结束        thread1.join();        thread2.join();        thread3.join();        thread4.join();        System.out.println("主线程继续执行");        //todo 业务代码    }}
复制代码


运行结果


主要线程开始等待其他子线程执行线程1 线程开始线程2 线程开始线程3 线程开始线程4 线程开始线程3 线程执行完毕线程2 线程执行完毕线程1 线程执行完毕线程4 线程执行完毕主线程继续执行
复制代码


主线程继续干活是要等前面四个线程全部执行完毕后再继续的。但是这么搞有点麻烦,那就是每个线程都得调用 join 方法,有没有更好玩的的呢?


答案是有的,它来了。


它就是 juc 下面的一个很牛逼的并发工具类CountDownLatch。是 JDK1.5 的时候有的,言外之意就是在 JDK1.5 之前就只能用 join 方法了。


方法二


CountDownLatch 中我们主要用到两个方法一个是 await()方法,调用这个方法的线程会被阻塞,另外一个是 countDown()方法,调用这个方法会使计数器减一,当计数器的值为 0 时,因调用 await()方法被阻塞的线程会被唤醒,继续执行。请看代码:


import java.util.Random;import java.util.concurrent.CountDownLatch;public class CountDownLatchDemo {    public static void main(String[] args) {        System.out.println("主要线程开始等待其他子线程执行");        test();    }    public static void test() {        int threadCount = 5;        CountDownLatch countDownLatch = new CountDownLatch(threadCount);        for (int i = 0; i < threadCount; i++) {            final int finalI = i + 1;            new Thread(() -> {                System.out.println("第 " + finalI + " 线程开始");                Random random = new Random();                try {                    Thread.sleep(random.nextInt(10000) + 1000);                } catch (InterruptedException e) {                    e.printStackTrace();                }                System.out.println("第 " + finalI + " 线程执行完毕");                countDownLatch.countDown();            }).start();        }        try {            countDownLatch.await();        } catch (InterruptedException e) {            e.printStackTrace();        }        System.out.println(threadCount + " 个线程全部执行完毕");        System.out.println("主线程继续执行");        //todo业务代码    }}
复制代码


输出


主要线程开始等待其他子线程执行第 1 线程开始第 2 线程开始第 3 线程开始第 4 线程开始第 5 线程开始第 1 线程执行完毕第 2 线程执行完毕第 5 线程执行完毕第 4 线程执行完毕第 3 线程执行完毕5 个线程全部执行完毕主线程继续执行
复制代码


面试能把这两种方式说出来,证明你还是可以解决这个问题。


但问题来了,如果面试官问你实现原理,你却回答不上来,就会给人你在瞎用的感觉,这样好不容易前面拿到点好印象结果被打回原型。


至于 join 的原理,建议去看看我之前发的线程常用方法里:快速掌握并发编程---Thread常用方法,那里面说的很清楚了,所这里就不在重复了。


今天我们着重了了 CountDownLatch。


CountDownLatch


概念


CountDownLatch 可以使一个获多个线程等待其他线程各自执行完毕后再执行。


CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为 0 之前,阻塞队列里面的线程处于挂起状态,当计数器递减到 0 时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch 可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。


整体



常用方法


构造方法


我们在上面的案例中


 int threadCount = 5; CountDownLatch countDownLatch = new CountDownLatch(threadCount);
复制代码


有用到 new CountDownLatch(threadCount);来创建一个 CountDownLatch 实例对象。我们看看这个构造方法


private final Sync sync;public CountDownLatch(int count) {     //记者count值不能小于0    if (count < 0) throw new IllegalArgumentException("count < 0");    //创建一个Sync实例对象入参就是count    this.sync = new Sync(count);}
复制代码


然后这里有个内部类 Sync,这个 Sync 内部类也没几行代码,Sync 继承了 AbstractQueuedSynchronizer 抽象队列同步器(以下简称 AQS)。


private static final class Sync extends AbstractQueuedSynchronizer {        private static final long serialVersionUID = 4982264981922014374L;        //入参count        Sync(int count) {            //这个setState方法还记得否?就是上篇文章AQS中的setState()方法            //就是给AQS中的state赋值,state=count            setState(count);        }        //获取AQS中state的值        int getCount() {            return getState();        }        protected int tryAcquireShared(int acquires) {            return (getState() == 0) ? 1 : -1;        }        //死循环        protected boolean tryReleaseShared(int releases) {            for (;;) {                //获取AQS中的state                int c = getState();                //如果AQS中的state==0,就返回false                if (c == 0)  return false;                int nextc = c-1;                //nextc=state-1                //                if (compareAndSetState(c, nextc))                    return nextc == 0;            }        } }
复制代码


countDown 方法


public void countDown() {    //调用的就是AQS中的方法    sync.releaseShared(1);}
复制代码


AQS 中 releaseShared 方法


public final boolean releaseShared(int arg) {    // arg 为固定值 1    // 如果计数器state 为0 返回true,前提是调用 countDown() 之前不能已经为0    //tryReleaseShared在AQS是空方法    if (tryReleaseShared(arg)) {      // 唤醒等待队列的线程        doReleaseShared();         return true;    }    return false;}protected boolean tryReleaseShared(int arg) {   throw new UnsupportedOperationException();}
复制代码


这个方法 tryReleaseShared()是在 CountDownLatch 中内部类 Sync 中实现的


//其实也没什么新招//还是死循环+CAS配合 实现计数器state减1protected boolean tryReleaseShared(int releases) {    // Decrement count; signal when transition to zero    for (;;) {        int c = getState();        if (c == 0)  return false;        int nextc = c-1;        if (compareAndSetState(c, nextc))           return nextc == 0;     }}
复制代码


方法 doReleaseShared 却是 AQS 种实现的(因为 CountDownLatch 和其内部类都没有实现,只有 AQS 实现了,那就只认 AQS 中的实现了)。


//实现思路就是从头到尾的遍历列队中所有的节点为shared状态的private void doReleaseShared() {        //死循环        for (;;) {            //获取当前列队的头节点            Node h = head;            //列队中可能为空列队,也有可能只有一个node节点            if (h != null && h != tail) {                //获取头节点的状态                int ws = h.waitStatus;                //如果头节点为SIGNAL状态, 说明后继节点需要唤醒                if (ws == Node.SIGNAL) {                    //将头结点的waitstatue设置为0,以后就不会再次唤醒后继节点了。                    //这一步是为了解决并发问题,保证只unpark一次!!不成功就继续                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))                        continue;            // loop to recheck cases                    //(释放)唤醒头节点的后继节点                    unparkSuccessor(h);                }// 状态为0并且不成功,继续                else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))                    continue;// loop on failed CAS            }            // 若头结点改变,继续循环             if (h == head) // loop if head changed                break;        }}
复制代码


整个调用逻辑大致为



await 方法


在 CountDownLatch 中 await 犯法


public void await() throws InterruptedException {   sync.acquireSharedInterruptibly(1);}
复制代码


然后调用 AQS 中的


public final void acquireSharedInterruptibly(int arg) throws InterruptedException {      //判断是否被中断过      if (Thread.interrupted()) throw new InterruptedException();      //如果state不等于0的时候      if (tryAcquireShared(arg) < 0){            doAcquireSharedInterruptibly(arg);      }}
复制代码


其中方法 tryAcquireShared(arg)是 CountDownLatch 的内部类 Sync 的 tryAcquireShared 方法


protected int tryAcquireShared(int acquires) {  //判断AQS中的state是否已经等于0了,等于翻译1否则返回-1  return (getState() == 0) ? 1 : -1;}
复制代码


再调用 AQS 中的 doAcquireSharedInterruptibly 方法


 //这个方法就是将当前线程封装成node节点加入到列队中,并判断是否需要阻塞当前线程 //这个节点都会被设置成shared状态,这样做的目的时当state值为0时会唤醒所有shared的节点private void doAcquireSharedInterruptibly(int arg)        throws InterruptedException {        //这个方法应该很熟悉了吧,前面的文章都介绍过,将当前线程封装成节点加入到列队中        final Node node = addWaiter(Node.SHARED);        boolean failed = true;        try {            //(又是死循环)一直执行,直到获取锁,返回            for (;;) {                //获取前驱节点                final Node p = node.predecessor();                //前驱节点为头结点                if (p == head) {                    //所以再次尝试获取信号量,这就是上面分析的那个获取方法                    int r = tryAcquireShared(arg);                    //如果r大于0证明获取信号量获取成功了证明下一个可以获取信号量的线程是当前线程                    if (r >= 0) {                        //将当前节点变成列队的head节点然后返回                        setHeadAndPropagate(node, r);                        //方便GC                        p.next = null;                         failed = false;                        return;                    }                }               //判断是否要进入阻塞状态.如果shouldParkAfterFailedAcquire方法               //返回true,表示需要进入阻塞               //调用parkAndCheckInterrupt;否则表示还可以再次尝试获取锁,继续进行for循环                if (shouldParkAfterFailedAcquire(p, node) &&                    parkAndCheckInterrupt())                    throw new InterruptedException();            }        } finally {            //失败就放弃            if (failed){                cancelAcquire(node);            }        }}
复制代码


方法 shouldParkAfterFailedAcquire 是 AQS 的


//p是前驱结点,node是当前结点private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {    int ws = pred.waitStatus; //获取前驱节点的状态    if (ws == Node.SIGNAL) //表明前驱节点可以运行        return true;    if (ws > 0) { //如果前驱节点状态大于0表明已经中断,        do {            node.prev = pred = pred.prev;         } while (pred.waitStatus > 0);        pred.next = node;    } else {        //等于0进入这里        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);     }    //只有前节点状态为NodeSIGNAL才返回真    return false; }
复制代码


我们对 shouldParkAfterFailedAcquire 来进行一个整体的概述,首先应该明白节点的状态,节点的状态是为了表明当前线程的良好度,如果当前线程被打断了,在唤醒的过程中是不是应该忽略该线程


 static final class Node {        static final int CANCELLED =  1;        static final int SIGNAL    = -1;        static final int CONDITION = -2;        static final int PROPAGATE = -3;       //....
复制代码


目前你只需知道大于 0 时表明该线程已近被取消,已近是无效节点,不应该被唤醒,注意:初始化链头节点时头节点状态值为 0。


shouldParkAfterFailedAcquire 是位于无限 for 循环内的,这一点需要注意一般每个节点都会经历两次循环后然后被阻塞。


在 AQS 的 doAcquireSharedInterruptibly 中可能会再次调用 CountDownLatch 的内部类 Sync 的 tryAcquireShared 方法和 AQS 的 setHeadAndPropagate 方法。setHeadAndPropagate 方法源码如下。


private void setHeadAndPropagate(Node node, int propagate) {        // 获取头结点        Node h = head;         // 设置头结点        setHead(node);        // 进行判断        if (propagate > 0 || h == null || h.waitStatus < 0 ||            (h = head) == null || h.waitStatus < 0) {            // 获取节点的后继            Node s = node.next;            if (s == null || s.isShared()) // 后继为空或者为共享模式                // 以共享模式进行释放                doReleaseShared();        }    }
复制代码


该方法设置头结点并且释放头结点后面的满足条件的结点,该方法中可能会调用到 AQS 的 doReleaseShared 方法,其源码如下。


private void doReleaseShared() {        // 无限循环        for (;;) {            // 保存头结点            Node h = head;            if (h != null && h != tail) { // 头结点不为空并且头结点不为尾结点                // 获取头结点的等待状态                int ws = h.waitStatus;                 if (ws == Node.SIGNAL) { // 状态为SIGNAL                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续                        continue;            // loop to recheck cases                    // 释放后继结点                    unparkSuccessor(h);                }                else if (ws == 0 &&                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续                    continue;                // loop on failed CAS            }            if (h == head) // 若头结点改变,继续循环                 break;        }    }
复制代码


CountDownLatch 的 await 调用大致会有如下的调用链



经典使用场景


CountDownLatch 的一个非常典型的应用场景是:有一个任务想要往下执行,但必须要等到其他的任务执行完毕后才可以继续往下执行。假如我们这个想要继续往下执行的任务调用一个 CountDownLatch 对象的 await()方法,其他的任务执行完自己的任务后调用同一个 CountDownLatch 对象上的 countDown()方法,这个调用 await()方法的任务将一直阻塞等待,直到这个 CountDownLatch 对象的计数值减到 0 为止。


案例 1


举个例子,有三个工人在为老板干活,这个老板有一个习惯,就是当三个工人把一天的活都干完了的时候,他就来检查所有工人所干的活。记住这个条件:三个工人先全部干完活,老板才检查。


案例 2


比如读取 excel 表格,需要把 execl 表格中的多个 sheet 进行数据汇总,为了提高汇总的效率我们一般会开启多个线程,每个线程去读取一个 sheet,可是线程执行是异步的,我们不知道什么时候数据处理结束了。那么这个时候我们就可以运用 CountDownLatch,有几个 sheet 就把 state 初始化几。每个线程执行完就调用 countDown()方法,在汇总的地方加上 await()方法,当所有线程执行完了,就可以进行数据的汇总了。


END


扫描关注公众号“Java 后端技术全栈”


解锁程序员的狂野世界



发布于: 2020 年 11 月 02 日阅读数: 49
用户头像

田维常

关注

关注公众号:Java后端技术全栈,领500G资料 2020.10.24 加入

关注公众号:Java后端技术全栈,领500G资料

评论

发布
暂无评论
快速掌握并发编程---CountDownLatch原理和实战